diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index da47ec8..67e83c0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,10 +1,17 @@ name: Rust + +# run on push and pull request to main and release branches on: push: - branches: [ "main" ] + branches: + - main + - release/* pull_request: - branches: [ "main" ] + branches: + - main + - release/* + env: CARGO_TERM_COLOR: always @@ -27,7 +34,7 @@ jobs: # run clippy to verify we have no warnings - run: cargo fetch - name: cargo clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy -p mqrstt test: name: Test @@ -45,7 +52,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 + - uses: EmbarkStudios/cargo-deny-action@v2 coverage: name: Coverage @@ -77,4 +84,5 @@ jobs: - name: Upload coverage report uses: codecov/codecov-action@v3 with: + token: ${{ secrets.CODECOV_TOKEN }} files: ./lcov.txt diff --git a/.gitignore b/.gitignore index 9b5ed2a..0a11e9d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,7 @@ **/target examples/tokio_tls/Cargo.lock examples/smol_tls/Cargo.lock -.vscode/** \ No newline at end of file +.vscode/** + +Cargo.lock +test.py \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index a8b4739..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,1944 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "aho-corasick" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" -dependencies = [ - "memchr", -] - -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - -[[package]] -name = "anstyle" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" - -[[package]] -name = "async-channel" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" -dependencies = [ - "concurrent-queue", - "event-listener 2.5.3", - "futures-core", -] - -[[package]] -name = "async-channel" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" -dependencies = [ - "concurrent-queue", - "event-listener 4.0.2", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-executor" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ae5ebefcc48e7452b4987947920dac9450be1110cadf34d1b8c116bdbaf97c" -dependencies = [ - "async-lock 3.2.0", - "async-task", - "concurrent-queue", - "fastrand 2.0.1", - "futures-lite 2.1.0", - "slab", -] - -[[package]] -name = "async-fs" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "279cf904654eeebfa37ac9bb1598880884924aab82e290aa65c9e77a0e142e06" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "async-fs" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd1f344136bad34df1f83a47f3fd7f2ab85d75cb8a940af4ccf6d482a84ea01b" -dependencies = [ - "async-lock 3.2.0", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "async-io" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "cfg-if", - "concurrent-queue", - "futures-lite 1.13.0", - "log", - "parking", - "polling 2.8.0", - "rustix 0.37.27", - "slab", - "socket2 0.4.10", - "waker-fn", -] - -[[package]] -name = "async-io" -version = "2.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6afaa937395a620e33dc6a742c593c01aced20aa376ffb0f628121198578ccc7" -dependencies = [ - "async-lock 3.2.0", - "cfg-if", - "concurrent-queue", - "futures-io", - "futures-lite 2.1.0", - "parking", - "polling 3.3.1", - "rustix 0.38.28", - "slab", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-lock" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" -dependencies = [ - "event-listener 2.5.3", -] - -[[package]] -name = "async-lock" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" -dependencies = [ - "event-listener 4.0.2", - "event-listener-strategy", - "pin-project-lite", -] - -[[package]] -name = "async-net" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0434b1ed18ce1cf5769b8ac540e33f01fa9471058b5e89da9e06f3c882a8c12f" -dependencies = [ - "async-io 1.13.0", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "async-net" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" -dependencies = [ - "async-io 2.2.2", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "async-process" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6438ba0a08d81529c69b36700fa2f95837bfe3e776ab39cde9c14d9149da88" -dependencies = [ - "async-io 1.13.0", - "async-lock 2.8.0", - "async-signal", - "blocking", - "cfg-if", - "event-listener 3.1.0", - "futures-lite 1.13.0", - "rustix 0.38.28", - "windows-sys 0.48.0", -] - -[[package]] -name = "async-process" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c1cd5d253ecac3d3cf15e390fd96bd92a13b1d14497d81abf077304794fb04" -dependencies = [ - "async-channel 2.1.1", - "async-io 2.2.2", - "async-lock 3.2.0", - "async-signal", - "blocking", - "cfg-if", - "event-listener 4.0.2", - "futures-lite 2.1.0", - "rustix 0.38.28", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-rustls" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93b21a03b7c21702a0110f9f8d228763a533570deb376119042dabf33c37a01a" -dependencies = [ - "futures-io", - "rustls 0.20.9", - "webpki", -] - -[[package]] -name = "async-rustls" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd10f063fb367d26334e10c50c67ea31ac542b8c3402be2251db4cfc5d74ba66" -dependencies = [ - "futures-io", - "rustls 0.21.10", -] - -[[package]] -name = "async-signal" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" -dependencies = [ - "async-io 2.2.2", - "async-lock 2.8.0", - "atomic-waker", - "cfg-if", - "futures-core", - "futures-io", - "rustix 0.38.28", - "signal-hook-registry", - "slab", - "windows-sys 0.48.0", -] - -[[package]] -name = "async-task" -version = "4.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d90cd0b264dfdd8eb5bad0a2c217c1f88fa96a8573f40e7b12de23fb468f46" - -[[package]] -name = "async-trait" -version = "0.1.77" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "backtrace" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - -[[package]] -name = "base64" -version = "0.21.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" - -[[package]] -name = "blocking" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" -dependencies = [ - "async-channel 2.1.1", - "async-lock 3.2.0", - "async-task", - "fastrand 2.0.1", - "futures-io", - "futures-lite 2.1.0", - "piper", - "tracing", -] - -[[package]] -name = "bumpalo" -version = "3.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" - -[[package]] -name = "bytes" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" - -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "ciborium" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" - -[[package]] -name = "ciborium-ll" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" -dependencies = [ - "ciborium-io", - "half", -] - -[[package]] -name = "clap" -version = "4.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d" -dependencies = [ - "clap_builder", -] - -[[package]] -name = "clap_builder" -version = "4.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" -dependencies = [ - "anstyle", - "clap_lex", -] - -[[package]] -name = "clap_lex" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" - -[[package]] -name = "concurrent-queue" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap", - "criterion-plot", - "futures", - "is-terminal", - "itertools", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "tokio", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" -dependencies = [ - "cfg-if", - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" -dependencies = [ - "autocfg", - "cfg-if", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - -[[package]] -name = "errno" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "event-listener" -version = "2.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" - -[[package]] -name = "event-listener" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93877bcde0eb80ca09131a08d23f0a5c18a620b01db137dba666d18cd9b30c2" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener" -version = "4.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "218a870470cce1469024e9fb66b901aa983929d81304a1cdb299f28118e550d5" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" -dependencies = [ - "event-listener 4.0.2", - "pin-project-lite", -] - -[[package]] -name = "fastrand" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] - -[[package]] -name = "fastrand" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" - -[[package]] -name = "futures" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - -[[package]] -name = "futures-executor" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-lite" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" -dependencies = [ - "fastrand 1.9.0", - "futures-core", - "futures-io", - "memchr", - "parking", - "pin-project-lite", - "waker-fn", -] - -[[package]] -name = "futures-lite" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" -dependencies = [ - "fastrand 2.0.1", - "futures-core", - "futures-io", - "parking", - "pin-project-lite", -] - -[[package]] -name = "futures-macro" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" - -[[package]] -name = "futures-task" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" - -[[package]] -name = "futures-timer" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" - -[[package]] -name = "futures-util" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - -[[package]] -name = "getrandom" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "gimli" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" - -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - -[[package]] -name = "hermit-abi" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "is-terminal" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" -dependencies = [ - "hermit-abi", - "rustix 0.38.28", - "windows-sys 0.52.0", -] - -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" - -[[package]] -name = "js-sys" -version = "0.3.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.151" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" - -[[package]] -name = "linux-raw-sys" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" - -[[package]] -name = "linux-raw-sys" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" - -[[package]] -name = "log" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" - -[[package]] -name = "matchers" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" -dependencies = [ - "regex-automata 0.1.10", -] - -[[package]] -name = "memchr" -version = "2.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" - -[[package]] -name = "miniz_oxide" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" -dependencies = [ - "adler", -] - -[[package]] -name = "mio" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.48.0", -] - -[[package]] -name = "mqrstt" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a854e678a3a205f8ac238694a8aae480684784f604a29367e47bd93a67bad83c" -dependencies = [ - "async-channel 1.9.0", - "async-trait", - "bytes", - "futures", - "smol 1.3.0", - "thiserror", - "tokio", -] - -[[package]] -name = "mqrstt" -version = "0.3.0" -dependencies = [ - "async-channel 2.1.1", - "async-rustls 0.4.1", - "bytes", - "criterion", - "futures", - "rand", - "rstest", - "rustls 0.21.10", - "rustls-pemfile", - "smol 2.0.0", - "thiserror", - "tokio", - "tokio-rustls", - "tracing", - "tracing-subscriber", - "webpki", -] - -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - -[[package]] -name = "num-traits" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "object" -version = "0.32.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" -dependencies = [ - "memchr", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - -[[package]] -name = "parking" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" - -[[package]] -name = "pin-project-lite" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "piper" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" -dependencies = [ - "atomic-waker", - "fastrand 2.0.1", - "futures-io", -] - -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "polling" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" -dependencies = [ - "autocfg", - "bitflags 1.3.2", - "cfg-if", - "concurrent-queue", - "libc", - "log", - "pin-project-lite", - "windows-sys 0.48.0", -] - -[[package]] -name = "polling" -version = "3.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf63fa624ab313c11656b4cda960bfc46c410187ad493c41f6ba2d8c1e991c9e" -dependencies = [ - "cfg-if", - "concurrent-queue", - "pin-project-lite", - "rustix 0.38.28", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro2" -version = "1.0.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - -[[package]] -name = "regex" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", -] - -[[package]] -name = "regex-automata" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax 0.8.2", -] - -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" - -[[package]] -name = "relative-path" -version = "1.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" - -[[package]] -name = "ring" -version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" -dependencies = [ - "cc", - "libc", - "once_cell", - "spin 0.5.2", - "untrusted 0.7.1", - "web-sys", - "winapi", -] - -[[package]] -name = "ring" -version = "0.17.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" -dependencies = [ - "cc", - "getrandom", - "libc", - "spin 0.9.8", - "untrusted 0.9.0", - "windows-sys 0.48.0", -] - -[[package]] -name = "rstest" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" -dependencies = [ - "futures", - "futures-timer", - "rstest_macros", - "rustc_version", -] - -[[package]] -name = "rstest_macros" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" -dependencies = [ - "cfg-if", - "glob", - "proc-macro2", - "quote", - "regex", - "relative-path", - "rustc_version", - "syn", - "unicode-ident", -] - -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - -[[package]] -name = "rustix" -version = "0.37.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] - -[[package]] -name = "rustix" -version = "0.38.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" -dependencies = [ - "bitflags 2.4.1", - "errno", - "libc", - "linux-raw-sys 0.4.12", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustls" -version = "0.20.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" -dependencies = [ - "log", - "ring 0.16.20", - "sct", - "webpki", -] - -[[package]] -name = "rustls" -version = "0.21.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" -dependencies = [ - "log", - "ring 0.17.7", - "rustls-webpki", - "sct", -] - -[[package]] -name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64", -] - -[[package]] -name = "rustls-webpki" -version = "0.101.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "ryu" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "semver" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" - -[[package]] -name = "serde" -version = "1.0.194" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b114498256798c94a0689e1a15fec6005dee8ac1f41de56404b67afc2a4b773" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.194" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3385e45322e8f9931410f01b3031ec534c3947d0e94c18049af4d9f9907d4e0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.110" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fbd975230bada99c8bb618e0c365c2eefa219158d5c6c29610fd09ff1833257" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" -dependencies = [ - "libc", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" - -[[package]] -name = "smol" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f2b548cd8447f8de0fdf1c592929f70f4fc7039a05e47404b0d096ec6987a1" -dependencies = [ - "async-channel 1.9.0", - "async-executor", - "async-fs 1.6.0", - "async-io 1.13.0", - "async-lock 2.8.0", - "async-net 1.8.0", - "async-process 1.8.1", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "smol" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e635339259e51ef85ac7aa29a1cd991b957047507288697a690e80ab97d07cad" -dependencies = [ - "async-channel 2.1.1", - "async-executor", - "async-fs 2.1.0", - "async-io 2.2.2", - "async-lock 3.2.0", - "async-net 2.0.0", - "async-process 2.0.1", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "smol_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-rustls 0.3.0", - "async-trait", - "futures", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "smol 1.3.0", - "webpki", -] - -[[package]] -name = "smol_tls_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-rustls 0.3.0", - "async-trait", - "futures", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "smol 1.3.0", - "webpki", -] - -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "socket2" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" -dependencies = [ - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "syn" -version = "2.0.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "mqrstt 0.2.2", -] - -[[package]] -name = "thiserror" -version = "1.0.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thread_local" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" -dependencies = [ - "cfg-if", - "once_cell", -] - -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "tokio" -version = "1.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "num_cpus", - "pin-project-lite", - "socket2 0.5.5", - "tokio-macros", - "windows-sys 0.48.0", -] - -[[package]] -name = "tokio-macros" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-rustls" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" -dependencies = [ - "rustls 0.21.10", - "tokio", -] - -[[package]] -name = "tokio_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-trait", - "mqrstt 0.2.2", - "rustls-pemfile", - "tokio", - "tokio-rustls", - "webpki", -] - -[[package]] -name = "tokio_tls_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-trait", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "tokio", - "tokio-rustls", - "webpki", -] - -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - -[[package]] -name = "waker-fn" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" - -[[package]] -name = "walkdir" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" -dependencies = [ - "same-file", - "winapi-util", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" - -[[package]] -name = "web-sys" -version = "0.3.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki" -version = "0.22.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.0", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" -dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" diff --git a/Cargo.toml b/Cargo.toml index 8c39591..9b444dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,5 +2,6 @@ members = [ "mqrstt", - "examples/*", + # "fuzz", + "examples/tcp" ] \ No newline at end of file diff --git a/deny.toml b/deny.toml index 7e10246..acabff1 100644 --- a/deny.toml +++ b/deny.toml @@ -1,18 +1,11 @@ -[advisories] -vulnerability = "deny" -unsound = "deny" -unmaintained = "deny" + [licenses] -unlicensed = "deny" -allow-osi-fsf-free = "neither" -copyleft = "deny" confidence-threshold = 0.95 -allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC"] +allow = ["MPL-2.0", "Apache-2.0", "MIT", "Unicode-3.0"] exceptions = [ { allow = ["Unicode-DFS-2016"], name = "unicode-ident" }, - { allow = ["OpenSSL"], name = "ring" } ] [[licenses.clarify]] diff --git a/examples/.gitignore b/examples/.gitignore deleted file mode 100644 index ce9f139..0000000 --- a/examples/.gitignore +++ /dev/null @@ -1 +0,0 @@ -**/Cargo.lock \ No newline at end of file diff --git a/examples/smol_tcp_v0.2.2/Cargo.toml b/examples/smol_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index 33ec470..0000000 --- a/examples/smol_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "smol_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["smol"]} - -smol = { version = "1.3.0" } -futures = { version = "0.3.27", default-features = false, features = ["std", "async-await"] } - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } -async-rustls = { version = "0.3.0" } diff --git a/examples/smol_tcp_v0.2.2/src/main.rs b/examples/smol_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 5254322..0000000 --- a/examples/smol_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,68 +0,0 @@ -use async_trait::async_trait; -use mqrstt::{ - new_smol, - packets::{self, Packet}, - smol::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -fn main() { - smol::block_on(async { - let client_id = "SmolTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_smol(options); - - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); - }); -} diff --git a/examples/smol_tls_v0.2.2/Cargo.toml b/examples/smol_tls_v0.2.2/Cargo.toml deleted file mode 100644 index e5402d5..0000000 --- a/examples/smol_tls_v0.2.2/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "smol_tls_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["smol"]} - -smol = { version = "1.3.0" } -futures = { version = "0.3.27", default-features = false, features = ["std", "async-await"] } - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } -async-rustls = { version = "0.3.0" } diff --git a/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt b/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt deleted file mode 100644 index fd4341d..0000000 --- a/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIQCDvgVpBCRrGhdWrJWZHHSjANBgkqhkiG9w0BAQUFADBh -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD -QTAeFw0wNjExMTAwMDAwMDBaFw0zMTExMTAwMDAwMDBaMGExCzAJBgNVBAYTAlVT -MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxGTAXBgNVBAsTEHd3dy5kaWdpY2VydC5j -b20xIDAeBgNVBAMTF0RpZ2lDZXJ0IEdsb2JhbCBSb290IENBMIIBIjANBgkqhkiG -9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4jvhEXLeqKTTo1eqUKKPC3eQyaKl7hLOllsB -CSDMAZOnTjC3U/dDxGkAV53ijSLdhwZAAIEJzs4bg7/fzTtxRuLWZscFs3YnFo97 -nh6Vfe63SKMI2tavegw5BmV/Sl0fvBf4q77uKNd0f3p4mVmFaG5cIzJLv07A6Fpt -43C/dxC//AH2hdmoRBBYMql1GNXRor5H4idq9Joz+EkIYIvUX7Q6hL+hqkpMfT7P -T19sdl6gSzeRntwi5m3OFBqOasv+zbMUZBfHWymeMr/y7vrTC0LUq7dBMtoM1O/4 -gdW7jVg/tRvoSSiicNoxBN33shbyTApOB6jtSj1etX+jkMOvJwIDAQABo2MwYTAO -BgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUA95QNVbR -TLtm8KPiGxvDl7I90VUwHwYDVR0jBBgwFoAUA95QNVbRTLtm8KPiGxvDl7I90VUw -DQYJKoZIhvcNAQEFBQADggEBAMucN6pIExIK+t1EnE9SsPTfrgT1eXkIoyQY/Esr -hMAtudXH/vTBH1jLuG2cenTnmCmrEbXjcKChzUyImZOMkXDiqw8cvpOp/2PV5Adg -06O/nVsJ8dWO41P0jmP6P6fbtGbfYmbW0W5BjfIttep3Sp+dWOIrWcBAI+0tKIJF -PnlUkiaY4IBIqDfv8NZ5YBberOgOzW6sRBc4L0na4UU+Krk2U886UAb3LujEV0ls -YSEY1QSteDwsOoBrp+uvFRTp2InBuThs4pFsiv9kuXclVzDAGySj4dzp30d8tbQk -CAUw7C29C79Fv1C5qfPrmAESrciIxpg0X40KPMbp1ZWVbd4= ------END CERTIFICATE----- diff --git a/examples/smol_tls_v0.2.2/src/main.rs b/examples/smol_tls_v0.2.2/src/main.rs deleted file mode 100644 index ee8e754..0000000 --- a/examples/smol_tls_v0.2.2/src/main.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::{ - io::{BufReader, Cursor}, - sync::Arc, -}; - -use async_trait::async_trait; -use mqrstt::{ - new_smol, - packets::{self, Packet}, - smol::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; -use rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; - -pub const EMQX_CERT: &[u8] = include_bytes!("broker.emqx.io-ca.crt"); - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[derive(Debug, Clone)] -pub enum PrivateKey { - RSA(Vec), - ECC(Vec), -} - -pub fn simple_rust_tls(ca: Vec, alpn: Option>>, client_auth: Option<(Vec, PrivateKey)>) -> Result, rustls::Error> { - let mut root_cert_store = RootCertStore::empty(); - - let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); - - let trust_anchors = ca_certs.iter().map_while(|cert| { - if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)) - } else { - None - } - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - assert!(!root_cert_store.is_empty()); - - let config = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store); - - let mut config = match client_auth { - Some((client_cert_info, client_private_info)) => { - let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))), - PrivateKey::ECC(ecc) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))), - } - .unwrap(); - - let key = read_private_keys.into_iter().next().unwrap(); - - let client_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))).unwrap(); - let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); - - config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))? - } - None => config.with_no_client_auth(), - }; - - if let Some(alpn) = alpn { - config.alpn_protocols.extend(alpn) - } - - Ok(Arc::new(config)) -} - -fn main() { - smol::block_on(async { - let client_id = "SmolTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_smol(options); - - let arc_client_config = simple_rust_tls(EMQX_CERT.to_vec(), None, None).unwrap(); - - let domain = ServerName::try_from(address).unwrap(); - let connector = async_rustls::TlsConnector::from(arc_client_config); - - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - let connection = connector.connect(domain, stream).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(connection, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); - }); -} diff --git a/examples/sync_tcp_v0.2.2/Cargo.toml b/examples/sync_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index c06ba83..0000000 --- a/examples/sync_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "sync_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["sync"]} diff --git a/examples/sync_tcp_v0.2.2/src/main.rs b/examples/sync_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 4abb3f1..0000000 --- a/examples/sync_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::time::Duration; - -use mqrstt::{ - new_sync, packets::{self, Packet}, sync::NetworkStatus, ConnectOptions, EventHandler, MqttClient -}; - -pub struct PingPong { - pub client: MqttClient, -} - -impl EventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish_blocking(p.topic.clone(), p.qos, p.retain, "pong").unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -fn main() { - let client_id = "SyncTcp_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 1883; - - let (mut network, client) = new_sync(options); - - let stream = std::net::TcpStream::connect((address, port)).unwrap(); - stream.set_nonblocking(true).unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).unwrap(); - - client.subscribe_blocking("mqrstt").unwrap(); - - let thread = std::thread::spawn(move || { - loop { - match network.poll(&mut pingpong) { - // The client is active but there is no data to be read - Ok(NetworkStatus::ActivePending) => std::thread::sleep(Duration::from_millis(100)), - // The client is active and there is data to be read - Ok(NetworkStatus::ActiveReady) => continue, - // The rest is an error - otherwise => return otherwise, - }; - } - }); - - std::thread::sleep(std::time::Duration::from_secs(30)); - client.disconnect_blocking().unwrap(); - - // Unwrap possible join errors on the thread. - let n = thread.join().unwrap(); - assert!(n.is_ok()); -} diff --git a/examples/tcp/Cargo.toml b/examples/tcp/Cargo.toml new file mode 100644 index 0000000..15e0444 --- /dev/null +++ b/examples/tcp/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "tcp" +version = "0.1.0" +edition = "2021" +license = "MIT" + +[dependencies] +smol = { version = "2" } +futures = "0.3.31" + +tokio = { version = "1", features = ["full"] } + +mqrstt = { path = "../../mqrstt", features = ["logs"] } + +[[bin]] +name = "tokio" +path = "src/tokio.rs" + +[[bin]] +name = "ping_pong" +path = "src/ping_pong.rs" + +[[bin]] +name = "ping_pong_smol" +path = "src/ping_pong_smol.rs" + +[[bin]] +name = "smol" +path = "src/smol.rs" diff --git a/examples/tcp/src/ping_pong.rs b/examples/tcp/src/ping_pong.rs new file mode 100644 index 0000000..3081554 --- /dev/null +++ b/examples/tcp/src/ping_pong.rs @@ -0,0 +1,54 @@ +use mqrstt::{ + packets::{self, Packet}, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, +}; +use tokio::time::Duration; + +pub struct PingPong { + pub client: MqttClient, +} +impl AsyncEventHandler for PingPong { + // Handlers only get INCOMING packets. + async fn handle(&mut self, event: packets::Packet) { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); + println!("Received Ping, Send pong!"); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } +} + +#[tokio::main] +async fn main() { + let (mut network, client) = NetworkBuilder::new_from_client_id("TokioTcpPingPongExample").tokio_network(); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); + + let mut pingpong = PingPong { client: client.clone() }; + + network.connect(stream, &mut pingpong).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let network_handle = tokio::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + tokio::time::sleep(Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = network_handle.await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); +} diff --git a/examples/tcp/src/ping_pong_smol.rs b/examples/tcp/src/ping_pong_smol.rs new file mode 100644 index 0000000..9af87bd --- /dev/null +++ b/examples/tcp/src/ping_pong_smol.rs @@ -0,0 +1,52 @@ +use mqrstt::{ + packets::{self, Packet}, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, +}; +pub struct PingPong { + pub client: MqttClient, +} +impl AsyncEventHandler for PingPong { + // Handlers only get INCOMING packets. This can change later. + async fn handle(&mut self, event: packets::Packet) { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); + println!("Received Ping, Send pong!"); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } +} +fn main() { + smol::block_on(async { + let (mut network, client) = NetworkBuilder::new_from_client_id("mqrsttSmolExample").smol_network(); + let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + + let mut pingpong = PingPong { client: client.clone() }; + + network.connect(stream, &mut pingpong).await.unwrap(); + + // This subscribe is only processed when we run the network + client.subscribe("mqrstt").await.unwrap(); + + let task_handle = smol::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + smol::Timer::after(std::time::Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = task_handle.await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); + }); +} diff --git a/examples/tcp/src/smol.rs b/examples/tcp/src/smol.rs new file mode 100644 index 0000000..592d880 --- /dev/null +++ b/examples/tcp/src/smol.rs @@ -0,0 +1,42 @@ +use mqrstt::AsyncEventHandler; + +pub struct Handler { + byte_count: u64, +} + +impl AsyncEventHandler for Handler { + fn handle(&mut self, incoming_packet: mqrstt::packets::Packet) -> impl std::future::Future + Send + Sync { + async move { + if let mqrstt::packets::Packet::Publish(publish) = incoming_packet { + self.byte_count += publish.payload.len() as u64; + } + } + } +} + +fn main() { + smol::block_on(async { + let hostname = "broker.emqx.io:1883"; + + let mut handler = Handler { byte_count: 0 }; + + let stream = smol::net::TcpStream::connect(hostname).await.unwrap(); + let (mut network, client) = mqrstt::NetworkBuilder::new_from_client_id("TestClientABCDEFG").smol_network(); + + network.connect(stream, &mut handler).await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(5)).await; + + client.subscribe("testtopic/#").await.unwrap(); + + smol::spawn(async move { + network.run(&mut handler).await.unwrap(); + + dbg!(handler.byte_count); + }) + .detach(); + + smol::Timer::after(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(1)).await; + }); +} diff --git a/examples/tcp/src/tokio.rs b/examples/tcp/src/tokio.rs new file mode 100644 index 0000000..e3db001 --- /dev/null +++ b/examples/tcp/src/tokio.rs @@ -0,0 +1,41 @@ +use mqrstt::AsyncEventHandler; + +pub struct Handler { + byte_count: u64, +} + +impl AsyncEventHandler for Handler { + fn handle(&mut self, incoming_packet: mqrstt::packets::Packet) -> impl std::future::Future + Send + Sync { + async move { + if let mqrstt::packets::Packet::Publish(publish) = incoming_packet { + self.byte_count += publish.payload.len() as u64; + } + } + } +} + +#[tokio::main] +async fn main() { + let hostname = "broker.emqx.io:1883"; + + let mut handler = Handler { byte_count: 0 }; + + let stream = tokio::net::TcpStream::connect(hostname).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); + let (mut network, client) = mqrstt::NetworkBuilder::new_from_client_id("TestClientABCDEFG").tokio_network(); + + network.connect(stream, &mut handler).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + client.subscribe("testtopic/#").await.unwrap(); + + tokio::spawn(async move { + network.run(&mut handler).await.unwrap(); + + dbg!(handler.byte_count); + }); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; +} diff --git a/examples/tokio_tcp_v0.2.2/Cargo.toml b/examples/tokio_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index 6a5954e..0000000 --- a/examples/tokio_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "tokio_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["tokio"]} - -tokio = { version = "1.26.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } -tokio-rustls = "0.24.0" - -async-trait = "0.1.68" - -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } diff --git a/examples/tokio_tcp_v0.2.2/src/main.rs b/examples/tokio_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 4e98b94..0000000 --- a/examples/tokio_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::time::Duration; - -use async_trait::async_trait; -use mqrstt::{ - new_tokio, - packets::{self, Packet}, - tokio::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[tokio::main] -async fn main() { - let client_id = "TokioTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_tokio(options); - - let stream = tokio::net::TcpStream::connect((address, port)).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -} diff --git a/examples/tokio_tls_v0.2.2/Cargo.toml b/examples/tokio_tls_v0.2.2/Cargo.toml deleted file mode 100644 index 2890bf9..0000000 --- a/examples/tokio_tls_v0.2.2/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "tokio_tls_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["tokio"]} - -tokio = { version = "1.26.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } -tokio-rustls = "0.24.0" - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } diff --git a/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt b/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt deleted file mode 100644 index fd4341d..0000000 --- a/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIQCDvgVpBCRrGhdWrJWZHHSjANBgkqhkiG9w0BAQUFADBh -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD -QTAeFw0wNjExMTAwMDAwMDBaFw0zMTExMTAwMDAwMDBaMGExCzAJBgNVBAYTAlVT -MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxGTAXBgNVBAsTEHd3dy5kaWdpY2VydC5j -b20xIDAeBgNVBAMTF0RpZ2lDZXJ0IEdsb2JhbCBSb290IENBMIIBIjANBgkqhkiG -9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4jvhEXLeqKTTo1eqUKKPC3eQyaKl7hLOllsB -CSDMAZOnTjC3U/dDxGkAV53ijSLdhwZAAIEJzs4bg7/fzTtxRuLWZscFs3YnFo97 -nh6Vfe63SKMI2tavegw5BmV/Sl0fvBf4q77uKNd0f3p4mVmFaG5cIzJLv07A6Fpt -43C/dxC//AH2hdmoRBBYMql1GNXRor5H4idq9Joz+EkIYIvUX7Q6hL+hqkpMfT7P -T19sdl6gSzeRntwi5m3OFBqOasv+zbMUZBfHWymeMr/y7vrTC0LUq7dBMtoM1O/4 -gdW7jVg/tRvoSSiicNoxBN33shbyTApOB6jtSj1etX+jkMOvJwIDAQABo2MwYTAO -BgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUA95QNVbR -TLtm8KPiGxvDl7I90VUwHwYDVR0jBBgwFoAUA95QNVbRTLtm8KPiGxvDl7I90VUw -DQYJKoZIhvcNAQEFBQADggEBAMucN6pIExIK+t1EnE9SsPTfrgT1eXkIoyQY/Esr -hMAtudXH/vTBH1jLuG2cenTnmCmrEbXjcKChzUyImZOMkXDiqw8cvpOp/2PV5Adg -06O/nVsJ8dWO41P0jmP6P6fbtGbfYmbW0W5BjfIttep3Sp+dWOIrWcBAI+0tKIJF -PnlUkiaY4IBIqDfv8NZ5YBberOgOzW6sRBc4L0na4UU+Krk2U886UAb3LujEV0ls -YSEY1QSteDwsOoBrp+uvFRTp2InBuThs4pFsiv9kuXclVzDAGySj4dzp30d8tbQk -CAUw7C29C79Fv1C5qfPrmAESrciIxpg0X40KPMbp1ZWVbd4= ------END CERTIFICATE----- diff --git a/examples/tokio_tls_v0.2.2/src/main.rs b/examples/tokio_tls_v0.2.2/src/main.rs deleted file mode 100644 index 0cc1da9..0000000 --- a/examples/tokio_tls_v0.2.2/src/main.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::{ - io::{BufReader, Cursor}, - sync::Arc, - time::Duration, -}; - -use async_trait::async_trait; -use mqrstt::{ - new_tokio, - packets::{self, Packet}, - tokio::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; -use tokio_rustls::rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; - -pub const EMQX_CERT: &[u8] = include_bytes!("broker.emqx.io-ca.crt"); - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[derive(Debug, Clone)] -pub enum PrivateKey { - RSA(Vec), - ECC(Vec), -} - -pub fn simple_rust_tls(ca: Vec, alpn: Option>>, client_auth: Option<(Vec, PrivateKey)>) -> Result, rustls::Error> { - let mut root_cert_store = RootCertStore::empty(); - - let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); - - let trust_anchors = ca_certs.iter().map_while(|cert| { - if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)) - } else { - None - } - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - assert!(!root_cert_store.is_empty()); - - let config = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store); - - let mut config = match client_auth { - Some((client_cert_info, client_private_info)) => { - let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))), - PrivateKey::ECC(ecc) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))), - } - .unwrap(); - - let key = read_private_keys.into_iter().next().unwrap(); - - let client_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))).unwrap(); - let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); - - config.with_single_cert(client_cert_chain, tokio_rustls::rustls::PrivateKey(key)).unwrap() - } - None => config.with_no_client_auth(), - }; - - if let Some(alpn) = alpn { - config.alpn_protocols.extend(alpn) - } - - Ok(Arc::new(config)) -} - -#[tokio::main] -async fn main() { - let client_id = "TokioTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_tokio(options); - - let arc_client_config = simple_rust_tls(EMQX_CERT.to_vec(), None, None).unwrap(); - - let domain = ServerName::try_from(address).unwrap(); - let connector = tokio_rustls::TlsConnector::from(arc_client_config); - - let stream = tokio::net::TcpStream::connect((address, port)).await.unwrap(); - let connection = connector.connect(domain, stream).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(connection, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -} diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..1a45eee --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +target +corpus +artifacts +coverage diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..efee995 --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "mqrstt-fuzz" +version = "0.0.0" +publish = false +edition = "2021" +license = "MIT" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" + +bytes = "1" + +tokio = { version = "1", features = ["full"] } + +[dependencies.mqrstt] +path = "../mqrstt" + +[[bin]] +name = "fuzz_target_1" +path = "fuzz_targets/fuzz_target_1.rs" +test = false +doc = false +bench = false diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs new file mode 100644 index 0000000..2dc7634 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -0,0 +1,15 @@ +#![no_main] + +#[cfg(target_os = "linux")] +use libfuzzer_sys::fuzz_target; + +#[cfg(target_os = "linux")] +#[tokio::main(flavor = "current_thread")] +async fn test(mut data: &[u8]) { + let _ = mqrstt::packets::Packet::async_read(&mut data).await; +} + +#[cfg(target_os = "linux")] +fuzz_target!(|data: &[u8]| { + test(data); +}); diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index ba4a6d0..0b7d79a 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mqrstt" -version = "0.3.0" +version = "0.4.0" homepage = "https://github.com/GunnarMorrigan/mqrstt" repository = "https://github.com/GunnarMorrigan/mqrstt" documentation = "https://docs.rs/mqrstt" @@ -10,7 +10,6 @@ edition = "2021" license = "MPL-2.0" keywords = ["MQTT", "IoT", "MQTTv5", "messaging", "client"] description = "Pure rust MQTTv5 client implementation Smol and Tokio" - rust-version = "1.75" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -28,24 +27,21 @@ test = [] [dependencies] # Packets -bytes = "1.5.0" +bytes = "1" # Errors -thiserror = "1.0.53" -tracing = { version = "0.1.40", optional = true } +thiserror = "1" +tracing = { version = "0.1", optional = true } -async-channel = "2.1.1" -#async-mutex = "1.4.0" -futures = { version = "0.3.30", default-features = false, features = [ +async-channel = "2" +futures = { version = "0.3", default-features = false, features = [ "std", "async-await", ] } -# quic feature flag -# quinn = {version = "0.9.0", optional = true } # tokio feature flag -tokio = { version = "1.35.1", features = [ +tokio = { version = "1", features = [ "macros", "io-util", "net", @@ -53,10 +49,10 @@ tokio = { version = "1.35.1", features = [ ], optional = true } # smol feature flag -smol = { version = "2.0.0", optional = true } +smol = { version = "2", optional = true } [dev-dependencies] -criterion = { version = "0.5.1", features = ["async_tokio"] } +pretty_assertions = "1.4.1" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } @@ -76,10 +72,5 @@ rustls-pemfile = { version = "1.0.3" } webpki = { version = "0.22.4" } async-rustls = { version = "0.4.1" } tokio-rustls = "0.24.1" -rstest = "0.18.2" -rand = "0.8.5" - - -[[bench]] -name = "bench_main" -harness = false +rstest = "0.23.0" +rand = "0.8.5" \ No newline at end of file diff --git a/README.md b/mqrstt/README.md similarity index 58% rename from README.md rename to mqrstt/README.md index 58dff00..a331db9 100644 --- a/README.md +++ b/mqrstt/README.md @@ -10,8 +10,7 @@ `MQRSTT` is an MQTTv5 client that provides sync and async (smol and tokio) implementation. Because this crate aims to be runtime agnostic the user is required to provide their own data stream. -For an async approach the stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWriteExt`] traits. -For a sync approach the stream has to implement the [`std::io::Read`] and [`std::io::Write`] traits. +The stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWrite`] traits. @@ -22,11 +21,13 @@ For a sync approach the stream has to implement the [`std::io::Read`] and [`std: - TLS/TCP - Lean - Keep alive depends on actual communication +- This tokio implemention has been fuzzed using cargo-fuzz! + + ### To do -- no_std (Requires a lot of work to use no heap allocations and depend on stack) - Even More testing -- More documentation +- Add TLS examples to repository ## MSRV From 0.3 the tokio and smol variants will require MSRV: 1.75 due to async fn in trait feature. @@ -38,119 +39,90 @@ From 0.3 the tokio and smol variants will require MSRV: 1.75 due to async fn in - Create a new connection when an error or disconnect is encountered - Handlers only get incoming packets -### TLS: -TLS examples are too larger for a README. [TLS examples](https://github.com/GunnarMorrigan/mqrstt/tree/main/examples). ### Smol example: ```rust use mqrstt::{ - MqttClient, - ConnectOptions, - new_smol, packets::{self, Packet}, - AsyncEventHandler, - smol::NetworkStatus, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, }; -use bytes::Bytes; pub struct PingPong { pub client: MqttClient, } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { if payload.to_lowercase().contains("ping") { - self.client - .publish( - p.topic.clone(), - p.qos, - p.retain, - Bytes::from_static(b"pong"), - ) - .await - .unwrap(); + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); println!("Received Ping, Send pong!"); } } - }, - Packet::ConnAck(_) => { println!("Connected!") }, + } + Packet::ConnAck(_) => { + println!("Connected!") + } _ => (), } } } -smol::block_on(async { - let options = ConnectOptions::new("mqrsttSmolExample"); - let (mut network, client) = new_smol(options); - let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); - - let mut pingpong = PingPong { - client: client.clone(), - }; +fn main() { + smol::block_on(async { + let (mut network, client) = NetworkBuilder::new_from_client_id("mqrsttSmolExample").smol_network(); + let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - network.connect(stream, &mut pingpong).await.unwrap(); + let mut pingpong = PingPong { client: client.clone() }; - // This subscribe is only processed when we run the network - client.subscribe("mqrstt").await.unwrap(); + network.connect(stream, &mut pingpong).await.unwrap(); + + // This subscribe is only processed when we run the network + client.subscribe("mqrstt").await.unwrap(); + + let task_handle = smol::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + smol::Timer::after(std::time::Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = task_handle.await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); + }); +} - let (n, t) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -}); ``` ### Tokio example: ```rust use mqrstt::{ - MqttClient, - ConnectOptions, - new_tokio, packets::{self, Packet}, - AsyncEventHandler, - tokio::NetworkStatus, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, }; use tokio::time::Duration; -use bytes::Bytes; pub struct PingPong { pub client: MqttClient, } impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + // Handlers only get INCOMING packets. + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { if payload.to_lowercase().contains("ping") { - self.client - .publish( - p.topic.clone(), - p.qos, - p.retain, - Bytes::from_static(b"pong"), - ) - .await - .unwrap(); + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); println!("Received Ping, Send pong!"); } } - }, - Packet::ConnAck(_) => { println!("Connected!") }, + } + Packet::ConnAck(_) => { + println!("Connected!") + } _ => (), } } @@ -158,39 +130,30 @@ impl AsyncEventHandler for PingPong { #[tokio::main] async fn main() { - let options = ConnectOptions::new("TokioTcpPingPongExample"); - - let (mut network, client) = new_tokio(options); - - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); - - let mut pingpong = PingPong { - client: client.clone(), - }; - + let (mut network, client) = NetworkBuilder::new_from_client_id("TokioTcpPingPongExample").tokio_network(); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); + + let mut pingpong = PingPong { client: client.clone() }; + network.connect(stream, &mut pingpong).await.unwrap(); - + client.subscribe("mqrstt").await.unwrap(); - - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); + + let network_handle = tokio::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + tokio::time::sleep(Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = network_handle.await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); } + ``` ### Sync example: @@ -212,7 +175,7 @@ pub struct PingPong { impl EventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - fn handle(&mut self, event: packets::Packet) -> () { + fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -284,7 +247,6 @@ Licensed under * Mozilla Public License, Version 2.0, [(MPL-2.0)](https://choosealicense.com/licenses/mpl-2.0/) ## Contribution - Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, shall be licensed under MPL-2.0, without any additional terms or conditions. diff --git a/mqrstt/benches/bench_main.rs b/mqrstt/benches/bench_main.rs deleted file mode 100644 index 9d7f5e4..0000000 --- a/mqrstt/benches/bench_main.rs +++ /dev/null @@ -1,8 +0,0 @@ -use criterion::criterion_main; - -mod benchmarks; - -criterion_main! { - benchmarks::tokio::tokio_concurrent, - benchmarks::tokio::tokio_synchronous, -} diff --git a/mqrstt/benches/benchmarks/mod.rs b/mqrstt/benches/benchmarks/mod.rs deleted file mode 100644 index 6a66044..0000000 --- a/mqrstt/benches/benchmarks/mod.rs +++ /dev/null @@ -1,158 +0,0 @@ -use bytes::{BufMut, Bytes, BytesMut}; -use mqrstt::packets::{Disconnect, Packet, Publish}; - -pub mod tokio; - -fn fill_stuff(buffer: &mut BytesMut, publ_count: usize, publ_size: usize) { - empty_connect(buffer); - for i in 0..publ_count { - very_large_publish(i as u16, publ_size / 5).write(buffer).unwrap(); - } - empty_disconnect().write(buffer).unwrap(); -} - -fn empty_disconnect() -> Packet { - let discon = Disconnect { - reason_code: mqrstt::packets::reason_codes::DisconnectReasonCode::ServerBusy, - properties: Default::default(), - }; - - Packet::Disconnect(discon) -} - -fn empty_connect(buffer: &mut BytesMut) { - // let conn_ack = ConnAck{ - // connack_flags: ConnAckFlags::default(), - // reason_code: mqrstt::packets::reason_codes::ConnAckReasonCode::Success, - // connack_properties: Default::default(), - // }; - - // Packet::ConnAck(conn_ack) - // buffer.put_u8(0b0010_0000); // Connack flags - // buffer.put_u8(0x01); // Connack flags - // buffer.put_u8(0x00); // Reason code, - // buffer.put_u8(0x00); // empty properties - - buffer.put_u8(0x20); - buffer.put_u8(0x13); - buffer.put_u8(0x00); - buffer.put_u8(0x00); - buffer.put_u8(0x10); - buffer.put_u8(0x27); - buffer.put_u8(0x06); - buffer.put_u8(0x40); - buffer.put_u8(0x00); - buffer.put_u8(0x00); - buffer.put_u8(0x25); - buffer.put_u8(0x01); - buffer.put_u8(0x2a); - buffer.put_u8(0x01); - buffer.put_u8(0x29); - buffer.put_u8(0x01); - buffer.put_u8(0x22); - buffer.put_u8(0xff); - buffer.put_u8(0xff); - buffer.put_u8(0x28); - buffer.put_u8(0x01); -} - -/// Returns Publish Packet with 5x `repeat` as payload in bytes. -fn very_large_publish(id: u16, repeat: usize) -> Packet { - let publ = Publish { - dup: false, - qos: mqrstt::packets::QoS::ExactlyOnce, - retain: false, - topic: "BlaBla".into(), - packet_identifier: Some(id), - publish_properties: Default::default(), - payload: Bytes::from_iter("ping".repeat(repeat).into_bytes()), - }; - - Packet::Publish(publ) -} - -mod test_handlers { - use std::{ - sync::{atomic::AtomicU16, Arc}, - time::Duration, - }; - - use bytes::Bytes; - use mqrstt::{ - packets::{self, Packet}, - AsyncEventHandler, AsyncEventHandlerMut, MqttClient, - }; - - pub struct PingPong { - pub client: MqttClient, - pub number: Arc, - } - - impl PingPong { - pub fn new(client: MqttClient) -> Self { - Self { - client, - number: Arc::new(AtomicU16::new(0)), - } - } - } - - impl AsyncEventHandler for PingPong { - async fn handle(&self, event: packets::Packet) -> () { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - let max_len = payload.len().min(10); - let _a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - } - } - } - Packet::ConnAck(_) => (), - _ => (), - } - } - } - - impl AsyncEventHandlerMut for PingPong { - async fn handle(&mut self, event: packets::Packet) -> () { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - let max_len = payload.len().min(10); - let _a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - } - } - } - Packet::ConnAck(_) => (), - _ => (), - } - } - } - - pub struct SimpleDelay { - delay: Duration, - } - - impl SimpleDelay { - pub fn new(delay: Duration) -> Self { - Self { delay } - } - } - - impl AsyncEventHandler for SimpleDelay { - fn handle(&self, _: Packet) -> impl futures::prelude::Future + Send + Sync { - tokio::time::sleep(self.delay) - } - } - impl AsyncEventHandlerMut for SimpleDelay { - fn handle(&mut self, _: Packet) -> impl futures::prelude::Future + Send + Sync { - tokio::time::sleep(self.delay) - } - } -} diff --git a/mqrstt/benches/benchmarks/tokio.rs b/mqrstt/benches/benchmarks/tokio.rs deleted file mode 100644 index cdd1950..0000000 --- a/mqrstt/benches/benchmarks/tokio.rs +++ /dev/null @@ -1,285 +0,0 @@ -use std::{hint::black_box, io::Write, net::SocketAddr, sync::Arc, time::Duration}; - -use bytes::BytesMut; -use criterion::{criterion_group, Criterion}; -use mqrstt::{ConnectOptions, NetworkBuilder, NetworkStatus}; -use tokio::net::TcpStream; - -use crate::benchmarks::test_handlers::{PingPong, SimpleDelay}; - -use super::fill_stuff; - -fn tokio_setup() -> (TcpStream, std::net::TcpStream, SocketAddr) { - let mut buffer = BytesMut::new(); - - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let tcp_stream = std::net::TcpStream::connect(addr).unwrap(); - - let (mut server, _addr) = listener.accept().unwrap(); - - fill_stuff(&mut buffer, 100, 5_000_000); - - server.write_all(&buffer.to_vec()).unwrap(); - - let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).unwrap(); - (tcp_stream, server, _addr) -} - -fn tokio_concurrent_benchmarks(c: &mut Criterion) { - let mut group = c.benchmark_group("Tokio concurrent read, write and handling"); - group.sample_size(30); - group.measurement_time(Duration::from_secs(120)); - - group.bench_function("tokio_bench_concurrent_read_write_and_handling_NOP", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - let (read, write) = network.split(()).unwrap(); - - let _network_box = black_box(network); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - let read_res = read_res.unwrap(); - assert_eq!(read_res, NetworkStatus::IncomingDisconnect); - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_and_handling_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let mut pingpong = Arc::new(PingPong::new(client.clone())); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - let (read, write) = network.split(pingpong.clone()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - let read_res = read_res.unwrap(); - assert_eq!(read_res, NetworkStatus::IncomingDisconnect); - assert_eq!(102, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _server_box = black_box(client.clone()); - let _server_box = black_box(server); - let _addr_box = black_box(addr); - let _network_box = black_box(network); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_and_handling_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let _server_box = black_box(client); - - let mut handler = Arc::new(SimpleDelay::new(Duration::from_millis(100))); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - let (read, write) = network.split(handler).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _network_box = black_box(network); - }) - }); - - group.bench_function("tokio_bench_concurrent_read_write", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - - let (read, write) = network.split(()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut pingpong = PingPong::new(client.clone()); - - let num_packets_received = pingpong.number.clone(); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - let (read, write) = network.split(pingpong).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - assert_eq!(102, num_packets_received.load(std::sync::atomic::Ordering::SeqCst)); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _server_box = black_box(client.clone()); - let _server_box = black_box(server); - let _addr_box = black_box(addr); - let _network_box = black_box(network); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - let mut handler = SimpleDelay::new(Duration::from_millis(100)); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - let (read, write) = network.split(handler).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _network_box = black_box(network); - }) - }); -} - -fn tokio_synchronous_benchmarks(c: &mut Criterion) { - let mut group = c.benchmark_group("Tokio sequential"); - group.sample_size(30); - group.measurement_time(Duration::from_secs(120)); - - group.bench_function("tokio_bench_sync_read_write", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - - let network_res = network.run(&mut ()).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); - group.bench_function("tokio_bench_sync_read_write_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut pingpong = PingPong::new(client.clone()); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - - let network_res = network.run(&mut pingpong).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); - group.bench_function("tokio_bench_sync_read_write_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut handler = SimpleDelay::new(Duration::from_millis(100)); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - - let network_res = network.run(&mut handler).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); -} - -criterion_group!(tokio_concurrent, tokio_concurrent_benchmarks); -criterion_group!(tokio_synchronous, tokio_synchronous_benchmarks); diff --git a/mqrstt/src/available_packet_ids.rs b/mqrstt/src/available_packet_ids.rs index 1527d8e..6a0e28b 100644 --- a/mqrstt/src/available_packet_ids.rs +++ b/mqrstt/src/available_packet_ids.rs @@ -6,7 +6,7 @@ use tracing::{debug, error}; use crate::error::HandlerError; #[derive(Debug, Clone)] -pub struct AvailablePacketIds { +pub(crate) struct AvailablePacketIds { sender: Sender, } @@ -22,7 +22,7 @@ impl AvailablePacketIds { (apkid, r) } - pub fn mark_available(&self, pkid: u16) -> Result<(), HandlerError> { + pub(crate) fn mark_available(&self, pkid: u16) -> Result<(), HandlerError> { match self.sender.try_send(pkid) { Ok(_) => { #[cfg(feature = "logs")] diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index d886114..1e7c72d 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -1,5 +1,4 @@ use async_channel::{Receiver, Sender}; -use bytes::Bytes; #[cfg(feature = "logs")] use tracing::info; @@ -7,16 +6,18 @@ use tracing::info; use crate::{ error::ClientError, packets::{ - mqtt_traits::PacketValidation, - reason_codes::DisconnectReasonCode, - Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, Subscription}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, + mqtt_trait::PacketValidation, + DisconnectReasonCode, Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, SubscribeTopics}, + {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, }, }; #[derive(Debug, Clone)] -/// A Clonable client that can be used to perform MQTT operations -/// -/// This object is never self constructed but is a obtained by calling the builder functions on [`crate::NetworkBuilder`] +/// A Clonable client that can be used to send MQTT messages. +/// +/// This object can be obtained by calling the builder functions on [`crate::NetworkBuilder`] +/// +/// This client should be used in combination with a handler [`crate::AsyncEventHandler`] to receive and send messages. pub struct MqttClient { /// Provides this client with an available packet id or waits on it. available_packet_ids_r: Receiver, @@ -36,69 +37,17 @@ impl MqttClient { max_packet_size, } } - - /// This function is only here for you to use during testing of for example your handler - /// For a simple client look at [`MqttClient::test_client`] - #[cfg(feature = "test")] - pub fn test_custom_client(available_packet_ids_r: Receiver, to_network_s: Sender, max_packet_size: usize) -> Self { - Self { - available_packet_ids_r, - to_network_s, - max_packet_size, - } - } - - - - /// This function is only here for you to use during testing of for example your handler - /// For control over the input of this type look at [`MqttClient::test_custom_client`] - /// - /// The returned values should not be dropped otherwise the client won't be able to operate normally. - /// - /// # Example - /// ```ignore - /// let ( - /// client, // An instance of this client - /// ids, // Allows you to indicate which packet IDs have become available again. - /// network_receiver // Messages send through the `client` will be dispatched through this channel - /// ) = MqttClient::test_client(); - /// - /// // perform testing - /// - /// // Make sure to not drop these before the test is done! - /// std::hint::black_box((ids, network_receiver)); - /// ``` - #[cfg(feature = "test")] - pub fn test_client() -> (Self, crate::available_packet_ids::AvailablePacketIds, Receiver) { - use async_channel::unbounded; - - use crate::{available_packet_ids::AvailablePacketIds, util::constants::MAXIMUM_PACKET_SIZE}; - - let (available_packet_ids, available_packet_ids_r) = AvailablePacketIds::new(u16::MAX); - - let (s, r) = unbounded(); - - ( - Self { - available_packet_ids_r, - to_network_s: s, - max_packet_size: MAXIMUM_PACKET_SIZE as usize, - }, - available_packet_ids, - r, - ) - } } /// Async functions to perform MQTT operations impl MqttClient { /// Creates a subscribe packet that is then asynchronously transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -125,34 +74,36 @@ impl MqttClient { /// mqtt_client.subscribe(("final/test/topic", sub_options)).await; /// # }); /// ``` - pub async fn subscribe>(&self, into_subscribtions: A) -> Result<(), ClientError> { + pub async fn subscribe>(&self, into_subscribtions: A) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?; - let subscription: Subscription = into_subscribtions.into(); + let subscription: SubscribeTopics = into_subscribtions.into(); let sub = Subscribe::new(pkid, subscription.0); sub.validate(self.max_packet_size)?; self.to_network_s.send(Packet::Subscribe(sub)).await.map_err(|_| ClientError::NoNetworkChannel)?; + #[cfg(feature = "logs")] + info!("Send to network: Subscribe with ID {:?}", pkid); Ok(()) } /// Creates a subscribe packet with additional subscribe packet properties. /// The packet is then asynchronously transferred to the Network stack for transmission. /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; /// /// let sub_properties = SubscribeProperties{ - /// subscription_id: Some(1), + /// subscription_identifier: Some(1), /// user_properties: vec![], /// }; - /// + /// /// let sub_properties_clone = sub_properties.clone(); /// /// // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, @@ -184,7 +135,7 @@ impl MqttClient { /// mqtt_client.subscribe_with_properties(("final/test/topic", sub_options), sub_properties).await; /// # }); /// ``` - pub async fn subscribe_with_properties>(&self, into_sub: S, properties: SubscribeProperties) -> Result<(), ClientError> { + pub async fn subscribe_with_properties>(&self, into_sub: S, properties: SubscribeProperties) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?; let sub = Subscribe { packet_identifier: pkid, @@ -202,7 +153,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -227,7 +178,7 @@ impl MqttClient { /// /// # }); /// ``` - pub async fn publish, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { + pub async fn publish, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?), @@ -252,11 +203,11 @@ impl MqttClient { /// Creates a Publish packet with additional publish properties. /// The packet is then asynchronously transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -268,7 +219,7 @@ impl MqttClient { /// correlation_data: Some("correlation_data".into()), /// ..Default::default() /// }; - /// + /// /// # let properties_clone = properties.clone(); /// /// // publish a message with QoS 0, without a packet identifier @@ -299,7 +250,7 @@ impl MqttClient { /// # }); /// # let _network = std::hint::black_box(network); /// ``` - pub async fn publish_with_properties, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { + pub async fn publish_with_properties, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?), @@ -326,7 +277,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -373,13 +324,13 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; /// /// let properties = UnsubscribeProperties{ - /// user_properties: vec![("property".to_string(), "value".to_string())], + /// user_properties: vec![("property".into(), "value".into())], /// }; /// /// // Unsubscribe from a single topic specified as a string: @@ -387,7 +338,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topic, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as an array of string slices: @@ -395,7 +346,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topics.as_slice(), properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from a single topic specified as a String: @@ -403,7 +354,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topic, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as a Vec: @@ -411,7 +362,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topics, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as an array of String: @@ -442,7 +393,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect().await.unwrap(); @@ -468,11 +419,11 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; - /// use mqrstt::packets::reason_codes::DisconnectReasonCode; + /// use mqrstt::packets::DisconnectReasonCode; /// /// let properties = DisconnectProperties { /// reason_string: Some("Reason here".into()), @@ -495,16 +446,16 @@ impl MqttClient { impl MqttClient { /// Creates a subscribe packet that is then transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// This function blocks until the packet is queued for transmission /// Creates a subscribe packet that is then asynchronously transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscriptionOptions, RetainHandling}; @@ -530,9 +481,9 @@ impl MqttClient { /// mqtt_client.subscribe_blocking(("final/test/topic", sub_options)).unwrap(); /// # }); /// ``` - pub fn subscribe_blocking>(&self, into_subscribtions: A) -> Result<(), ClientError> { + pub fn subscribe_blocking>(&self, into_subscribtions: A) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?; - let subscription: Subscription = into_subscribtions.into(); + let subscription: SubscribeTopics = into_subscribtions.into(); let sub = Subscribe::new(pkid, subscription.0); sub.validate(self.max_packet_size)?; @@ -543,22 +494,22 @@ impl MqttClient { /// Creates a subscribe packet with additional subscribe packet properties. /// The packet is then transferred to the Network stack for transmission. /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; /// /// let sub_properties = SubscribeProperties{ - /// subscription_id: Some(1), + /// subscription_identifier: Some(1), /// user_properties: vec![], /// }; /// # let sub_properties_clone = sub_properties.clone(); - /// + /// /// // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, /// mqtt_client.subscribe_with_properties_blocking("test/topic", sub_properties).unwrap(); /// @@ -588,7 +539,7 @@ impl MqttClient { /// mqtt_client.subscribe_with_properties_blocking(("final/test/topic", sub_options), sub_properties).unwrap(); /// # }); /// ``` - pub fn subscribe_with_properties_blocking>(&self, into_subscribtions: S, properties: SubscribeProperties) -> Result<(), ClientError> { + pub fn subscribe_with_properties_blocking>(&self, into_subscribtions: S, properties: SubscribeProperties) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?; let sub = Subscribe { packet_identifier: pkid, @@ -603,14 +554,14 @@ impl MqttClient { /// Creates a Publish packet which is then transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::QoS; /// use bytes::Bytes; /// @@ -633,7 +584,7 @@ impl MqttClient { /// /// # }); /// ``` - pub fn publish_blocking, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { + pub fn publish_blocking, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?), @@ -658,15 +609,15 @@ impl MqttClient { /// Creates a Publish packet with additional publish properties. /// The packet is then transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// This function blocks until the packet is queued for transmission /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::QoS; /// use mqrstt::packets::PublishProperties; /// use bytes::Bytes; @@ -676,7 +627,7 @@ impl MqttClient { /// correlation_data: Some("correlation_data".into()), /// ..Default::default() /// }; - /// + /// /// # let properties_clone = properties.clone(); /// /// // publish a message with QoS 0, without a packet identifier @@ -706,7 +657,7 @@ impl MqttClient { /// /// # }); /// ``` - pub fn publish_with_properties_blocking, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { + pub fn publish_with_properties_blocking, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?), @@ -734,7 +685,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -756,7 +707,7 @@ impl MqttClient { /// // Unsubscribe from multiple topics specified as an array of String: /// let topics = &[String::from("test/topic1"), String::from("test/topic2")]; /// mqtt_client.unsubscribe_blocking(topics.as_slice()); - /// + /// /// # }); /// # std::hint::black_box(network); /// ``` @@ -782,13 +733,13 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; /// /// let properties = UnsubscribeProperties{ - /// user_properties: vec![("property".to_string(), "value".to_string())], + /// user_properties: vec![("property".into(), "value".into())], /// }; /// # let properties_clone = properties.clone(); /// @@ -819,7 +770,7 @@ impl MqttClient { /// // Unsubscribe from multiple topics specified as an array of String: /// let topics = ["test/topic1","test/topic2"]; /// mqtt_client.unsubscribe_with_properties_blocking(topics.as_slice(), properties); - /// + /// /// # }); /// # std::hint::black_box(network); /// ``` @@ -843,9 +794,9 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { - /// + /// /// mqtt_client.disconnect_blocking().unwrap(); /// /// # }); @@ -869,11 +820,11 @@ impl MqttClient { /// /// ``` /// - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::DisconnectProperties; - /// use mqrstt::packets::reason_codes::DisconnectReasonCode; + /// use mqrstt::packets::DisconnectReasonCode; /// /// let properties = DisconnectProperties { /// reason_string: Some("Reason here".into()), @@ -891,14 +842,14 @@ impl MqttClient { } } -#[cfg(any(feature = "tokio", feature = "smol", feature = "quic"))] +#[cfg(any(feature = "tokio", feature = "smol"))] #[cfg(test)] mod tests { use async_channel::Receiver; use crate::{ error::{ClientError, PacketValidationError}, - packets::{reason_codes::DisconnectReasonCode, DisconnectProperties, Packet, PacketType, Publish, QoS, Subscribe, SubscribeProperties, UnsubscribeProperties}, + packets::{DisconnectProperties, DisconnectReasonCode, Packet, PacketType, Publish, QoS, SubscribeProperties, UnsubscribeProperties}, }; use super::MqttClient; @@ -946,15 +897,15 @@ mod tests { #[tokio::test] async fn test_subscribe_with_properties() { let (mqtt_client, client_to_handler_r, to_network_r) = create_new_test_client(); - - let sub_properties = SubscribeProperties{ - subscription_id: Some(1), + + let sub_properties = SubscribeProperties { + subscription_identifier: Some(1), user_properties: vec![], }; // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, let res = mqtt_client.subscribe_with_properties("test/topic", sub_properties.clone()).await; - + assert!(res.is_ok()); let packet = client_to_handler_r.recv().await.unwrap(); // assert!(matches!(packet, Packet::Subscribe(sub) if sub.properties.subscription_id == Some(1))); @@ -964,7 +915,7 @@ mod tests { } #[test] - + fn test_subscribe_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); @@ -1015,7 +966,6 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[test] fn test_unsubscribe_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); @@ -1041,13 +991,12 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[test] fn test_unsubscribe_with_properties_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); - let properties = UnsubscribeProperties{ - user_properties: vec![("property".to_string(), "value".to_string())], + let properties = UnsubscribeProperties { + user_properties: vec![("property".into(), "value".into())], }; // Unsubscribe from a single topic specified as a string: @@ -1090,12 +1039,11 @@ mod tests { assert_eq!(res.unwrap_err(), ClientError::ValidationError(PacketValidationError::TopicSize(65538))); } - #[tokio::test] async fn publish_with_properties() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_other_data".into()), ..Default::default() @@ -1107,7 +1055,7 @@ mod tests { assert!(res.is_ok()); let packet = client_to_handler_r.recv().await.unwrap(); - let publ = Publish{ + let publ = Publish { dup: false, qos: *qos, retain: false, @@ -1123,12 +1071,11 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[tokio::test] async fn publish_with_just_right_topic_len_properties() { let (client, _client_to_handler_r, _) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_data".into()), ..Default::default() @@ -1143,7 +1090,7 @@ mod tests { async fn publish_with_too_long_topic_properties() { let (client, _client_to_handler_r, _) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_data".into()), ..Default::default() @@ -1193,7 +1140,7 @@ mod tests { let (client, client_to_handler_r, _) = create_new_test_client(); let prop = UnsubscribeProperties { - user_properties: vec![("A".to_string(), "B".to_string())], + user_properties: vec![("A".into(), "B".into())], }; client.unsubscribe_with_properties("Topic", prop.clone()).await.unwrap(); @@ -1218,12 +1165,10 @@ mod tests { let disconnect = client_to_handler_r.recv().await.unwrap(); assert_eq!(PacketType::Disconnect, disconnect.packet_type()); - assert!( - matches!(disconnect, Packet::Disconnect(res) - if res.properties == DisconnectProperties::default() && - DisconnectReasonCode::NormalDisconnection == res.reason_code - ) - ); + assert!(matches!(disconnect, Packet::Disconnect(res) + if res.properties == DisconnectProperties::default() && + DisconnectReasonCode::NormalDisconnection == res.reason_code + )); } #[tokio::test] @@ -1251,7 +1196,6 @@ mod tests { assert!(matches!(disconnect, Packet::Disconnect(res) if properties == res.properties && DisconnectReasonCode::KeepAliveTimeout == res.reason_code)); } - #[test] fn test_disconnect_blocking() { let (client, client_to_handler_r, _) = create_new_test_client(); @@ -1280,5 +1224,4 @@ mod tests { assert!(matches!(disconnect, Packet::Disconnect(res) if properties == res.properties && DisconnectReasonCode::KeepAliveTimeout == res.reason_code)); } - } diff --git a/mqrstt/src/connect_options.rs b/mqrstt/src/connect_options.rs index 15af570..c5a8671 100644 --- a/mqrstt/src/connect_options.rs +++ b/mqrstt/src/connect_options.rs @@ -1,7 +1,5 @@ use std::time::Duration; -use bytes::Bytes; - use crate::util::constants::DEFAULT_RECEIVE_MAXIMUM; use crate::{ packets::{ConnectProperties, LastWill}, @@ -10,10 +8,11 @@ use crate::{ #[derive(Debug, thiserror::Error)] pub enum ConnectOptionsError { - #[error("Maximum packet size is exceeded. Maximum is {MAXIMUM_PACKET_SIZE}, was provided {0}")] - MaximumPacketSize(u32), + #[error("Maximum packet size is exceeded. Maximum is {MAXIMUM_PACKET_SIZE}, user provided: {0}")] + MaximumPacketSizeExceeded(u32), } +/// Options for the connection to the MQTT broker #[derive(Debug, Clone)] pub struct ConnectOptions { /// client identifier @@ -41,7 +40,7 @@ pub struct ConnectOptions { request_problem_information: Option, user_properties: Vec<(Box, Box)>, authentication_method: Option>, - authentication_data: Bytes, + authentication_data: Option>, /// Last will that will be issued on unexpected disconnect last_will: Option, @@ -62,9 +61,9 @@ impl Default for ConnectOptions { topic_alias_maximum: None, request_response_information: None, request_problem_information: None, - user_properties: vec![], + user_properties: Vec::new(), authentication_method: None, - authentication_data: Bytes::new(), + authentication_data: None, last_will: None, } } @@ -72,8 +71,11 @@ impl Default for ConnectOptions { impl ConnectOptions { /// Create a new [`ConnectOptions`] - /// ClientId recommendation: - /// - 1 to 23 bytes UTF-8 bytes + /// + /// Be aware: + /// This client does not restrict the client identifier in any way. However, the MQTT v5.0 specification does. + /// It is thus recommended to use a client id that is compatible with the MQTT v5.0 specification. + /// - 1 to 23 bytes UTF-8 bytes. /// - Contains [a-zA-Z0-9] characters only. /// /// Some brokers accept longer client ids with different characters @@ -94,7 +96,7 @@ impl ConnectOptions { request_problem_information: None, user_properties: vec![], authentication_method: None, - authentication_data: Bytes::new(), + authentication_data: None, last_will: None, } } @@ -219,7 +221,7 @@ impl ConnectOptions { pub fn set_maximum_packet_size(&mut self, maximum_packet_size: u32) -> Result<&mut Self, ConnectOptionsError> { if maximum_packet_size > MAXIMUM_PACKET_SIZE { - Err(ConnectOptionsError::MaximumPacketSize(maximum_packet_size)) + Err(ConnectOptionsError::MaximumPacketSizeExceeded(maximum_packet_size)) } else { self.maximum_packet_size = Some(maximum_packet_size); Ok(self) diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index 6475b33..b439912 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -3,9 +3,8 @@ use std::io; use async_channel::{RecvError, SendError}; use crate::packets::{ - error::{DeserializeError, ReadBytes, SerializeError}, - reason_codes::ConnAckReasonCode, - {Packet, PacketType}, + error::{DeserializeError, ReadBytes, ReadError, SerializeError, WriteError}, + ConnAckReasonCode, Packet, PacketType, }; /// Critical errors that can happen during the operation of the entire client @@ -43,7 +42,25 @@ pub enum ConnectionError { JoinError(#[from] tokio::task::JoinError), } -/// Errors that the [`crate::StateHandler`] can emit +impl From for ConnectionError { + fn from(value: ReadError) -> Self { + match value { + ReadError::DeserializeError(deserialize_error) => ConnectionError::DeserializationError(deserialize_error), + ReadError::IoError(error) => ConnectionError::Io(error), + } + } +} + +impl From for ConnectionError { + fn from(value: WriteError) -> Self { + match value { + WriteError::SerializeError(error) => ConnectionError::SerializationError(error), + WriteError::IoError(error) => ConnectionError::Io(error), + } + } +} + +/// Errors that the internal StateHandler can emit #[derive(Debug, Clone, thiserror::Error)] pub enum HandlerError { #[error("Missing Packet ID")] @@ -52,8 +69,8 @@ pub enum HandlerError { #[error("The incoming channel between network and handler is closed")] IncomingNetworkChannelClosed, - #[error("The outgoing channel between handler and network is closed: {0}")] - OutgoingNetworkChannelClosed(#[from] SendError), + #[error("The outgoing channel between handler and network is closed")] + OutgoingNetworkChannelClosed, #[error("Channel between client and handler closed")] ClientChannelClosed, @@ -71,6 +88,12 @@ pub enum HandlerError { UnexpectedPacket(PacketType), } +impl From> for HandlerError { + fn from(_: SendError) -> Self { + HandlerError::OutgoingNetworkChannelClosed + } +} + /// Errors producable by the [`crate::MqttClient`] #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum ClientError { diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 5ef33f4..b69fca6 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -1,50 +1,19 @@ -use std::sync::Arc; - use futures::Future; use crate::packets::Packet; - -/// Handlers are used to deal with packets before they are further processed (acked) -/// This guarantees that the end user has handlded the packet. -/// Trait for async mutable access to handler. -/// Usefull when you have a single handler - -/// This trait can be used types which +/// Handlers are used to deal with packets before they are acknowledged to the broker. +/// This guarantees that the end user has handlded the packet. Additionally, handlers only deal with incoming packets. +/// +/// This handler can be used to handle message sequentialy. +/// +/// To send messages look at [`crate::MqttClient`] pub trait AsyncEventHandler { - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; -} -impl AsyncEventHandler for &T -where - T: AsyncEventHandler, -{ - #[inline] - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync { - AsyncEventHandler::handle(*self, incoming_packet) - } -} -impl AsyncEventHandler for Arc -where - T: AsyncEventHandler, -{ - #[inline] - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync { - ::handle(&self, incoming_packet) - } -} -impl AsyncEventHandler for () { - fn handle(&self, _: Packet) -> impl Future + Send + Sync { - async {} - } -} - -pub trait AsyncEventHandlerMut { fn handle(&mut self, incoming_packet: Packet) -> impl Future + Send + Sync; } -impl AsyncEventHandlerMut for () { - fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { - async {} - } +/// This is a simple no operation handler. +impl AsyncEventHandler for () { + async fn handle(&mut self, _: Packet) {} } pub trait EventHandler { @@ -62,14 +31,14 @@ pub mod example_handlers { use crate::{ packets::{self, Packet}, - AsyncEventHandler, AsyncEventHandlerMut, EventHandler, MqttClient, + AsyncEventHandler, EventHandler, MqttClient, }; /// Most basic no op handler /// This handler performs no operations on incoming messages. pub struct NOP {} - impl AsyncEventHandlerMut for NOP { + impl AsyncEventHandler for NOP { async fn handle(&mut self, _: Packet) {} } @@ -79,23 +48,20 @@ pub mod example_handlers { pub struct PingResp { pub client: MqttClient, - pub ping_resp_received: AtomicU16, + pub ping_resp_received: u32, } impl PingResp { pub fn new(client: MqttClient) -> Self { - Self { - client, - ping_resp_received: AtomicU16::new(0), - } + Self { client, ping_resp_received: 0 } } } - impl AsyncEventHandlerMut for PingResp { - async fn handle(&mut self, event: packets::Packet) -> () { + impl AsyncEventHandler for PingResp { + async fn handle(&mut self, event: packets::Packet) { use Packet::*; if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.ping_resp_received += 1; } println!("Received packet: {}", event); } @@ -105,7 +71,7 @@ pub mod example_handlers { fn handle(&mut self, event: Packet) { use Packet::*; if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.ping_resp_received += 1; } println!("Received packet: {}", event); } @@ -123,34 +89,7 @@ pub mod example_handlers { } impl AsyncEventHandler for PingPong { - async fn handle(&self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - // let max_len = payload.len().min(10); - // let a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - // println!("Received publish payload: {}", a); - - if !p.retain { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - - // println!("DBG: \n {}", &Packet::Publish(p)); - } - } - } - Packet::ConnAck(_) => { - // println!("Connected!") - } - _ => (), - } - } - } - - impl AsyncEventHandlerMut for PingPong { - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index 1de3c95..e5ac218 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -1,31 +1,26 @@ //! A pure rust MQTT client which is easy to use, efficient and provides both sync and async options. //! //! Because this crate aims to be runtime agnostic the user is required to provide their own data stream. -//! For an async approach the stream has to implement the `AsyncReadExt` and `AsyncWriteExt` traits. -//! That is [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] for tokio and [`::smol::io::AsyncReadExt`] and [`::smol::io::AsyncWriteExt`] for smol. +//! For an async approach the stream has to implement the `AsyncRead` and `AsyncWrite` traits. +//! That is [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] for tokio and [`::smol::io::AsyncRead`] and [`::smol::io::AsyncWrite`] for smol. //! //! Features: //! ---------------------------- //! - MQTT v5 //! - Runtime agnostic (Smol, Tokio) -//! - TLS/TCP +//! - Packets are acknoledged after handler has processed them +//! - Runs on just a stream so you can use all TCP backends //! - Lean //! - Keep alive depends on actual communication //! //! To do //! ---------------------------- -//! - Enforce size of outbound messages (e.g. Publish) -//! - QUIC via QUINN //! - Even More testing -//! - More documentation -//! - Remove logging calls or move all to test flag //! //! Notes: //! ---------------------------- -//! - Your handler should not wait too long -//! - Create a new connection when an error or disconnect is encountered //! - Handlers only get incoming packets -//! - Sync mode requires a non blocking stream +//! - Create a new connection when an error or disconnect is encountered //! //! Smol example: //! ---------------------------- @@ -48,7 +43,7 @@ //! // To reconnect after a disconnect or error //! let (mut network, client) = NetworkBuilder //! ::new_from_client_id("mqrsttSmolExample") -//! .smol_sequential_network(); +//! .smol_network(); //! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) //! .await //! .unwrap(); @@ -87,7 +82,7 @@ //! async fn main() { //! let (mut network, client) = NetworkBuilder //! ::new_from_client_id("TokioTcpPingPongExample") -//! .tokio_sequential_network(); +//! .tokio_network(); //! //! // Construct a no op handler //! let mut nop = NOP{}; @@ -111,59 +106,6 @@ //! assert!(n.is_ok()); //! } //! ``` -//! -// //! Sync example: -// //! ---------------------------- -// //! ```rust -// //! use mqrstt::{ -// //! MqttClient, -// //! example_handlers::NOP, -// //! ConnectOptions, -// //! packets::{self, Packet}, -// //! EventHandler, -// //! sync::NetworkStatus, -// //! }; -// //! use std::net::TcpStream; -// //! -// //! let mut client_id: String = "SyncTcppingrespTestExample".to_string(); -// //! let options = ConnectOptions::new(client_id); -// //! -// //! let address = "broker.emqx.io"; -// //! let port = 1883; -// //! -// //! let (mut network, client) = new_sync(options); -// //! -// //! // Construct a no op handler -// //! let mut nop = NOP{}; -// //! -// //! // In normal operations you would want to loop connect -// //! // To reconnect after a disconnect or error -// //! let stream = TcpStream::connect((address, port)).unwrap(); -// //! // IMPORTANT: Set nonblocking to true! No progression will be made when stream reads block! -// //! stream.set_nonblocking(true).unwrap(); -// //! network.connect(stream, &mut nop).unwrap(); -// //! -// //! let res_join_handle = std::thread::spawn(move || -// //! loop { -// //! match network.poll(&mut nop) { -// //! Ok(NetworkStatus::ActivePending) => { -// //! std::thread::sleep(std::time::Duration::from_millis(100)); -// //! }, -// //! Ok(NetworkStatus::ActiveReady) => { -// //! std::thread::sleep(std::time::Duration::from_millis(100)); -// //! }, -// //! otherwise => return otherwise, -// //! } -// //! } -// //! ); -// //! -// //! std::thread::sleep(std::time::Duration::from_secs(30)); -// //! client.disconnect_blocking().unwrap(); -// //! let join_res = res_join_handle.join(); -// //! assert!(join_res.is_ok()); -// //! let res = join_res.unwrap(); -// //! assert!(res.is_ok()); -// //! ``` const CHANNEL_SIZE: usize = 100; @@ -173,16 +115,29 @@ mod connect_options; mod state_handler; mod util; +/// Contains the reader writer parts for the smol runtime. +/// +/// Module [`crate::smol`] only contains a synchronized approach to call the users `Handler`. #[cfg(feature = "smol")] pub mod smol; -#[cfg(any(feature = "tokio"))] +/// Contains the reader and writer parts for the tokio runtime. +/// +/// Module [`crate::tokio`] contains both a synchronized and concurrent approach to call the users `Handler`. +#[cfg(feature = "tokio")] pub mod tokio; +/// Error types that the user can see during operation of the client. +/// +/// Wraps all other errors that can be encountered. pub mod error; + +/// All event handler traits are defined here. +/// +/// Event handlers are used to process incoming packets. mod event_handlers; +/// All MQTT packets are defined here pub mod packets; mod state; -use std::marker::PhantomData; pub use event_handlers::*; @@ -190,6 +145,7 @@ pub use client::MqttClient; pub use connect_options::ConnectOptions; use state_handler::StateHandler; +use std::marker::PhantomData; #[cfg(test)] pub mod tests; @@ -237,14 +193,10 @@ impl NetworkBuilder { #[cfg(feature = "tokio")] impl NetworkBuilder where - H: AsyncEventHandlerMut, - S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler, + S: ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite + Sized + Unpin, { - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] - /// This network is supposed to be ran on a single task/thread. The read and write operations happen one after the other. - /// This approach does not give the most speed in terms of reading and writing but provides a simple and easy to use client with low overhead for low throughput clients. - /// - /// For more throughput: [`NetworkBuilder::tokio_concurrent_network`] + /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] /// /// # Example /// ``` @@ -253,11 +205,11 @@ where /// let options = ConnectOptions::new("ExampleClient"); /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> /// ::new_from_options(options) - /// .tokio_sequential_network(); + /// .tokio_network(); /// ``` - pub fn tokio_sequential_network(self) -> (tokio::Network, MqttClient) + pub fn tokio_network(self) -> (tokio::Network, MqttClient) where - H: AsyncEventHandlerMut, + H: AsyncEventHandler, { let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); @@ -273,51 +225,19 @@ where } } -#[cfg(feature = "tokio")] -impl NetworkBuilder -where - H: AsyncEventHandler, - S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, -{ - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] - /// # Example - /// - /// ``` - /// use mqrstt::ConnectOptions; - /// - /// let options = ConnectOptions::new("ExampleClient"); - /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> - /// ::new_from_options(options) - /// .tokio_concurrent_network(); - /// ``` - pub fn tokio_concurrent_network(self) -> (tokio::Network, MqttClient) { - let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); - - let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); - - let max_packet_size = self.options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = tokio::Network::new(self.options, to_network_r, apkids); - - (network, client) - } -} - #[cfg(feature = "smol")] impl NetworkBuilder where - H: AsyncEventHandlerMut, - S: ::smol::io::AsyncReadExt + ::smol::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler, + S: ::smol::io::AsyncRead + ::smol::io::AsyncWrite + Sized + Unpin, { - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] + /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] /// ``` /// let (mut network, client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream> /// ::new_from_client_id("ExampleClient") - /// .smol_sequential_network(); + /// .smol_network(); /// ``` - pub fn smol_sequential_network(self) -> (smol::Network, MqttClient) { + pub fn smol_network(self) -> (smol::Network, MqttClient) { let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); @@ -332,38 +252,6 @@ where } } -#[cfg(feature = "todo")] -/// Creates a new [`sync::Network`] and [`MqttClient`] that can be connected to a broker. -/// S should implement [`std::io::Read`] and [`std::io::Write`]. -/// Additionally, S should be made non_blocking otherwise it will not progress. -/// -/// # Example -/// -/// ``` -/// use mqrstt::ConnectOptions; -/// -/// let options = ConnectOptions::new("ExampleClient"); -/// let (network, client) = mqrstt::new_sync::(options); -/// ``` -pub fn new_sync(options: ConnectOptions) -> (sync::Network, MqttClient) -where - S: std::io::Read + std::io::Write + Sized + Unpin, -{ - use available_packet_ids::AvailablePacketIds; - - let (to_network_s, to_network_r) = async_channel::bounded(100); - - let (apkids, apkids_r) = AvailablePacketIds::new(options.send_maximum()); - - let max_packet_size = options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = sync::Network::new(options, to_network_r, apkids); - - (network, client) -} - #[cfg(test)] fn random_chars() -> String { rand::Rng::sample_iter(rand::thread_rng(), &rand::distributions::Alphanumeric).take(7).map(char::from).collect() @@ -377,19 +265,19 @@ mod smol_lib_test { use rand::Rng; - use crate::{example_handlers::PingPong, packets::QoS, ConnectOptions, NetworkBuilder}; + use crate::{example_handlers::PingPong, packets::QoS, random_chars, ConnectOptions, NetworkBuilder}; #[test] fn test_smol_tcp() { smol::block_on(async { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); + let mut client_id: String = random_chars(); client_id += "_SmolTcpPingPong"; let options = ConnectOptions::new(client_id); let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingpong = PingPong::new(client.clone()); @@ -426,7 +314,7 @@ mod smol_lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); @@ -450,11 +338,11 @@ mod smol_lib_test { ); assert!(n.is_ok()); let pingresp = n.unwrap(); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); + assert_eq!(2, pingresp.ping_resp_received); }); } - #[cfg(all(target_family = "windows"))] + #[cfg(target_family = "windows")] #[test] fn test_close_write_tcp_stream_smol() { use crate::error::ConnectionError; @@ -472,7 +360,7 @@ mod smol_lib_test { let (n, _) = futures::join!( async { - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); network.connect(stream, &mut pingresp).await @@ -496,99 +384,46 @@ mod smol_lib_test { #[cfg(feature = "tokio")] #[cfg(test)] mod tokio_lib_test { - use crate::example_handlers::PingPong; - - use crate::packets::QoS; - - use std::{sync::Arc, time::Duration}; - + use crate::example_handlers::PingResp; + use crate::random_chars; use crate::ConnectOptions; - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn test_tokio_tcp() { - use std::hint::black_box; + use std::time::Duration; - use crate::NetworkBuilder; + #[tokio::test] + async fn test_tokio_ping_req() { + let mut client_id: String = random_chars(); + client_id += "_TokioTcppingrespTest"; + let mut options = ConnectOptions::new(client_id); + let keep_alive_interval = 5; + options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); - let client_id: String = crate::random_chars() + "_TokioTcpPingPong"; + let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - let (mut network, client) = NetworkBuilder::new_from_client_id(client_id).tokio_concurrent_network(); + let (mut network, client) = crate::NetworkBuilder::new_from_options(options).tokio_network(); let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - let mut pingpong = Arc::new(PingPong::new(client.clone())); - - network.connect(stream, &mut pingpong).await.unwrap(); - - let topic = crate::random_chars() + "_mqrstt"; + let mut pingresp = PingResp::new(client.clone()); - client.subscribe((topic.as_str(), QoS::ExactlyOnce)).await.unwrap(); + network.connect(stream, &mut pingresp).await.unwrap(); - tokio::time::sleep(Duration::from_secs(5)).await; - - let (read, write) = network.split(pingpong.clone()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_result, write_result, _) = tokio::join!(read_handle, write_handle, async { - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".to_vec()).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".to_vec()).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); - - client.unsubscribe(topic.as_str()).await.unwrap(); - - for _ in 0..30 { - tokio::time::sleep(Duration::from_secs(1)).await; - if pingpong.number.load(std::sync::atomic::Ordering::SeqCst) == 4 { - break; - } - } - - client.disconnect().await.unwrap(); + let network_handle = tokio::task::spawn(async move { + let _result = network.run(&mut pingresp).await; + // check result and or restart the connection + pingresp }); - let write_result = write_result.unwrap(); - assert!(write_result.is_ok()); - assert_eq!(crate::NetworkStatus::OutgoingDisconnect, write_result.unwrap()); - assert_eq!(4, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); - let _ = black_box(read_result); - } - - // #[tokio::test] - // async fn test_tokio_ping_req() { - // let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - // client_id += "_TokioTcppingrespTest"; - // let mut options = ConnectOptions::new(client_id); - // let keep_alive_interval = 5; - // options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); + tokio::time::sleep(wait_duration).await; + client.disconnect().await.unwrap(); - // let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; + tokio::time::sleep(Duration::from_secs(1)).await; - // let (mut network, client) = new_tokio(options); - - // let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - - // let pingresp = Arc::new(crate::test_handlers::PingResp::new(client.clone())); - - // network.connect(stream, &mut pingresp).await.unwrap(); - - // let (read, write) = network.split(pingresp.clone()).unwrap(); - - // let read_handle = tokio::task::spawn(read.run()); - // let write_handle = tokio::task::spawn(write.run()); - - // tokio::time::sleep(wait_duration).await; - // client.disconnect().await.unwrap(); - - // tokio::time::sleep(Duration::from_secs(1)).await; - - // let (read_result, write_result) = tokio::join!(read_handle, write_handle); - // let (read_result, write_result) = (read_result.unwrap(), write_result.unwrap()); - // assert!(write_result.is_ok()); - // assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); - // } + let result = network_handle.await; + assert!(result.is_ok()); + let result = result.unwrap(); + assert_eq!(2, result.ping_resp_received); + } #[cfg(all(feature = "tokio", target_family = "windows"))] #[tokio::test] @@ -600,11 +435,11 @@ mod tokio_lib_test { let address = ("127.0.0.1", 2000); let client_id: String = crate::random_chars() + "_TokioTcppingrespTest"; - let options = ConnectOptions::new(client_id); + let options = crate::ConnectOptions::new(client_id); let (n, _) = tokio::join!( async move { - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_network(); let stream = tokio::net::TcpStream::connect(address).await.unwrap(); @@ -621,8 +456,7 @@ mod tokio_lib_test { ); if let ConnectionError::Io(err) = n.unwrap_err() { - assert_eq!(ErrorKind::ConnectionReset, err.kind()); - assert_eq!("Connection reset by peer".to_string(), err.to_string()); + assert_eq!(ErrorKind::UnexpectedEof, err.kind()); } else { panic!(); } diff --git a/mqrstt/src/packets/auth.rs b/mqrstt/src/packets/auth.rs deleted file mode 100644 index dc6ba68..0000000 --- a/mqrstt/src/packets/auth.rs +++ /dev/null @@ -1,149 +0,0 @@ -use bytes::Bytes; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::AuthReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Auth { - pub reason_code: AuthReasonCode, - pub properties: AuthProperties, -} - -impl VariableHeaderRead for Auth { - fn read(_: u8, _: usize, mut buf: Bytes) -> Result { - let reason_code = AuthReasonCode::read(&mut buf)?; - let properties = AuthProperties::read(&mut buf)?; - - Ok(Self { reason_code, properties }) - } -} - -impl VariableHeaderWrite for Auth { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - Ok(()) - } -} - -impl WireLength for Auth { - fn wire_len(&self) -> usize { - 1 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len() - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct AuthProperties { - /// 3.15.2.2.2 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method. - pub authentication_method: Option>, - - /// 3.15.2.2.3 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - pub authentication_data: Bytes, - - /// 3.15.2.2.4 Reason String - /// 31 (0x1F) Byte, Identifier of the Reason String - pub reason_string: Option>, - - /// 3.15.2.2.5 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for AuthProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = AuthProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::MalformedPacket); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Bytes::read(&mut property_data)?; - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for AuthProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(authentication_method) = &self.authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - PropertyType::AuthenticationData.write(buf)?; - self.authentication_data.write(buf)?; - } - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - - Ok(()) - } -} - -impl WireLength for AuthProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(authentication_method) = &self.authentication_method { - len += authentication_method.wire_len(); - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - len += self.authentication_data.wire_len(); - } - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += key.wire_len() + value.wire_len(); - } - len - } -} diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs new file mode 100644 index 0000000..6f55d4d --- /dev/null +++ b/mqrstt/src/packets/auth/mod.rs @@ -0,0 +1,69 @@ +mod properties; + +pub use properties::AuthProperties; +mod reason_code; +pub use reason_code::AuthReasonCode; + +use bytes::Bytes; + +use super::{ + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +/// The AUTH packet is used to perform more intriquite authentication methods. +/// +/// At the time of writing this client does not (yet) provide the user a method of handling the auth handshake. +/// There are several other ways to perform authentication, for example using TLS. +/// Additionally, not many clients support this packet fully. +pub struct Auth { + pub reason_code: AuthReasonCode, + pub properties: AuthProperties, +} + +impl PacketRead for Auth { + fn read(_: u8, _: usize, mut buf: Bytes) -> Result { + let reason_code = AuthReasonCode::read(&mut buf)?; + let properties = AuthProperties::read(&mut buf)?; + + Ok(Self { reason_code, properties }) + } +} + +impl PacketAsyncRead for Auth +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let (reason_code, reason_code_read_bytes) = AuthReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = AuthProperties::async_read(stream).await?; + + Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes)) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Auth +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let reason_code_written = self.reason_code.async_write(stream).await?; + let properties_written = self.properties.async_write(stream).await?; + Ok(reason_code_written + properties_written) + } +} + +impl PacketWrite for Auth { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + Ok(()) + } +} + +impl WireLength for Auth { + fn wire_len(&self) -> usize { + 1 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + } +} diff --git a/mqrstt/src/packets/auth/properties.rs b/mqrstt/src/packets/auth/properties.rs new file mode 100644 index 0000000..e905897 --- /dev/null +++ b/mqrstt/src/packets/auth/properties.rs @@ -0,0 +1,91 @@ +use bytes::Bytes; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + /// Properties of the AUTH packet + AuthProperties, + AuthenticationMethod, + AuthenticationData, + ReasonString, + UserProperty +); + +impl MqttRead for AuthProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = AuthProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::MalformedPacket); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for AuthProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(authentication_method) = &self.authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = &self.authentication_data { + if !authentication_data.is_empty() && self.authentication_method.is_some() { + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + } + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/auth/reason_code.rs b/mqrstt/src/packets/auth/reason_code.rs new file mode 100644 index 0000000..364e1aa --- /dev/null +++ b/mqrstt/src/packets/auth/reason_code.rs @@ -0,0 +1,5 @@ +crate::packets::macros::reason_code!(AuthReasonCode, + Success, + ContinueAuthentication, + ReAuthenticate +); \ No newline at end of file diff --git a/mqrstt/src/packets/connack.rs b/mqrstt/src/packets/connack.rs deleted file mode 100644 index fb198ab..0000000 --- a/mqrstt/src/packets/connack.rs +++ /dev/null @@ -1,527 +0,0 @@ -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::ConnAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, -}; -use bytes::{Buf, BufMut, Bytes}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct ConnAck { - /// 3.2.2.1 Connect Acknowledge Flags - pub connack_flags: ConnAckFlags, - - /// 3.2.2.2 Connect Reason Code - /// Byte 2 in the Variable Header is the Connect Reason Code. - pub reason_code: ConnAckReasonCode, - - /// 3.2.2.3 CONNACK Properties - pub connack_properties: ConnAckProperties, -} - -impl VariableHeaderRead for ConnAck { - fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { - if header_len > buf.len() { - return Err(DeserializeError::InsufficientData("ConnAck".to_string(), buf.len(), header_len)); - } - - let connack_flags = ConnAckFlags::read(&mut buf)?; - let reason_code = ConnAckReasonCode::read(&mut buf)?; - let connack_properties = ConnAckProperties::read(&mut buf)?; - - Ok(Self { - connack_flags, - reason_code, - connack_properties, - }) - } -} - -impl VariableHeaderWrite for ConnAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - self.connack_flags.write(buf)?; - self.reason_code.write(buf)?; - self.connack_properties.write(buf)?; - - Ok(()) - } -} - -impl WireLength for ConnAck { - fn wire_len(&self) -> usize { - 2 + // 1 for connack_flags and 1 for reason_code - variable_integer_len(self.connack_properties.wire_len()) + - self.connack_properties.wire_len() - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ConnAckProperties { - /// 3.2.2.3.2 Session Expiry Interval - /// 17 (0x11) Byte Identifier of the Session Expiry Interval - pub session_expiry_interval: Option, - - /// 3.2.2.3.3 Receive Maximum - /// 33 (0x21) Byte, Identifier of the Receive Maximum - pub receive_maximum: Option, - - /// 3.2.2.3.4 Maximum QoS - /// 36 (0x24) Byte, Identifier of the Maximum QoS. - pub maximum_qos: Option, - - /// 3.2.2.3.5 Retain Available - /// 37 (0x25) Byte, Identifier of Retain Available. - pub retain_available: Option, - - /// 3.2.2.3.6 Maximum Packet Size - /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. - pub maximum_packet_size: Option, - - /// 3.2.2.3.7 Assigned Client Identifier - /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. - pub assigned_client_id: Option>, - - /// 3.2.2.3.8 Topic Alias Maximum - /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. - pub topic_alias_maximum: Option, - - /// 3.2.2.3.9 Reason String - /// 31 (0x1F) Byte Identifier of the Reason String. - pub reason_string: Option>, - - /// 3.2.2.3.10 User Property - /// 38 (0x26) Byte, Identifier of User Property. - pub user_properties: Vec<(Box, Box)>, - - /// 3.2.2.3.11 Wildcard Subscription Available - /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. - pub wildcards_available: Option, - - /// 3.2.2.3.12 Subscription Identifiers Available - /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. - pub subscription_ids_available: Option, - - /// 3.2.2.3.13 Shared Subscription Available - /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. - pub shared_subscription_available: Option, - - /// 3.2.2.3.14 Server Keep Alive - /// 19 (0x13) Byte, Identifier of the Server Keep Alive - pub server_keep_alive: Option, - - /// 3.2.2.3.15 Response Information - /// 26 (0x1A) Byte, Identifier of the Response Information. - pub response_info: Option>, - - /// 3.2.2.3.16 Server Reference - /// 28 (0x1C) Byte, Identifier of the Server Reference - pub server_reference: Option>, - - /// 3.2.2.3.17 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method - pub authentication_method: Option>, - - /// 3.2.2.3.18 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - // There is a small inconsistency here with authentication_data in the connect packet. - // This is Option while that type uses just Bytes. - pub authentication_data: Option, -} - -impl MqttRead for ConnAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("ConnAckProperties".to_string(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - let property = PropertyType::read(&mut property_data)?; - match property { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ReceiveMaximum => { - if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); - } - properties.receive_maximum = Some(u16::read(&mut property_data)?); - } - PropertyType::MaximumQos => { - if properties.maximum_qos.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumQos)); - } - properties.maximum_qos = Some(QoS::read(&mut property_data)?); - } - PropertyType::RetainAvailable => { - if properties.retain_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable)); - } - properties.retain_available = Some(bool::read(&mut property_data)?); - } - PropertyType::MaximumPacketSize => { - if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); - } - properties.maximum_packet_size = Some(u32::read(&mut property_data)?); - } - PropertyType::AssignedClientIdentifier => { - if properties.assigned_client_id.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); - } - properties.assigned_client_id = Some(Box::::read(&mut property_data)?); - } - PropertyType::TopicAliasMaximum => { - if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); - } - properties.topic_alias_maximum = Some(u16::read(&mut property_data)?); - } - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::WildcardSubscriptionAvailable => { - if properties.wildcards_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable)); - } - properties.wildcards_available = Some(bool::read(&mut property_data)?); - } - PropertyType::SubscriptionIdentifierAvailable => { - if properties.subscription_ids_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable)); - } - properties.subscription_ids_available = Some(bool::read(&mut property_data)?); - } - PropertyType::SharedSubscriptionAvailable => { - if properties.shared_subscription_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable)); - } - properties.shared_subscription_available = Some(bool::read(&mut property_data)?); - } - PropertyType::ServerKeepAlive => { - if properties.server_keep_alive.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive)); - } - properties.server_keep_alive = Some(u16::read(&mut property_data)?); - } - PropertyType::ResponseInformation => { - if properties.response_info.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation)); - } - properties.response_info = Some(Box::::read(&mut property_data)?); - } - PropertyType::ServerReference => { - if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); - } - properties.server_reference = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Some(Bytes::read(&mut property_data)?); - } - - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::ConnAck)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for ConnAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - let Self { - session_expiry_interval, - receive_maximum, - maximum_qos, - retain_available, - maximum_packet_size, - assigned_client_id, - topic_alias_maximum, - reason_string, - user_properties, - wildcards_available, - subscription_ids_available, - shared_subscription_available, - server_keep_alive, - response_info, - server_reference, - authentication_method, - authentication_data, - } = self; - - if let Some(session_expiry_interval) = session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(*session_expiry_interval); - } - if let Some(receive_maximum) = receive_maximum { - PropertyType::ReceiveMaximum.write(buf)?; - buf.put_u16(*receive_maximum); - } - if let Some(maximum_qos) = maximum_qos { - PropertyType::MaximumQos.write(buf)?; - maximum_qos.write(buf)?; - } - if let Some(retain_available) = retain_available { - PropertyType::RetainAvailable.write(buf)?; - retain_available.write(buf)?; - } - if let Some(maximum_packet_size) = maximum_packet_size { - PropertyType::MaximumPacketSize.write(buf)?; - buf.put_u32(*maximum_packet_size); - } - if let Some(client_id) = assigned_client_id { - PropertyType::AssignedClientIdentifier.write(buf)?; - client_id.write(buf)?; - } - if let Some(topic_alias_maximum) = topic_alias_maximum { - PropertyType::TopicAliasMaximum.write(buf)?; - buf.put_u16(*topic_alias_maximum); - } - if let Some(reason_string) = reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, val) in user_properties.iter() { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - val.write(buf)?; - } - if let Some(wildcards_available) = wildcards_available { - PropertyType::WildcardSubscriptionAvailable.write(buf)?; - wildcards_available.write(buf)?; - } - if let Some(subscription_ids_available) = subscription_ids_available { - PropertyType::SubscriptionIdentifierAvailable.write(buf)?; - subscription_ids_available.write(buf)?; - } - if let Some(shared_subscription_available) = shared_subscription_available { - PropertyType::SharedSubscriptionAvailable.write(buf)?; - shared_subscription_available.write(buf)?; - } - if let Some(server_keep_alive) = server_keep_alive { - PropertyType::ServerKeepAlive.write(buf)?; - server_keep_alive.write(buf)?; - } - if let Some(response_info) = response_info { - PropertyType::ResponseInformation.write(buf)?; - response_info.write(buf)?; - } - if let Some(server_reference) = server_reference { - PropertyType::ServerReference.write(buf)?; - server_reference.write(buf)?; - } - if let Some(authentication_method) = &authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if let Some(authentication_data) = authentication_data { - if authentication_method.is_none() { - return Err(SerializeError::AuthDataWithoutAuthMethod); - } - PropertyType::AuthenticationData.write(buf)?; - authentication_data.write(buf)?; - } - - Ok(()) - } -} - -impl WireLength for ConnAckProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.session_expiry_interval.is_some() { - len += 1 + 4; - } - if self.receive_maximum.is_some() { - len += 1 + 2; - } - if self.maximum_qos.is_some() { - len += 1 + 1; - } - if self.retain_available.is_some() { - len += 1 + 1; - } - if self.maximum_packet_size.is_some() { - len += 1 + 4; - } - if let Some(client_id) = &self.assigned_client_id { - len += 1 + client_id.wire_len(); - } - if self.topic_alias_maximum.is_some() { - len += 1 + 2; - } - if let Some(reason_string) = &self.reason_string { - len += 1 + reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += 1; - len += key.wire_len(); - len += value.wire_len(); - } - if self.wildcards_available.is_some() { - len += 1 + 1; - } - if self.subscription_ids_available.is_some() { - len += 1 + 1; - } - if self.shared_subscription_available.is_some() { - len += 1 + 1; - } - if self.server_keep_alive.is_some() { - len += 1 + 2; - } - if let Some(response_info) = &self.response_info { - len += 1 + response_info.wire_len(); - } - if let Some(server_reference) = &self.server_reference { - len += 1 + server_reference.wire_len(); - } - if let Some(authentication_method) = &self.authentication_method { - len += 1 + authentication_method.wire_len(); - } - if self.authentication_data.is_some() && self.authentication_method.is_some() { - len += 1 + self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); - } - - len - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct ConnAckFlags { - pub session_present: bool, -} - -impl MqttRead for ConnAckFlags { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConnAckFlags".to_string(), 0, 1)); - } - - let byte = buf.get_u8(); - - Ok(Self { - session_present: (byte & 0b00000001) == 0b00000001, - }) - } -} - -impl MqttWrite for ConnAckFlags { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let byte = self.session_present as u8; - - buf.put_u8(byte); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - - use crate::packets::{ - connack::{ConnAck, ConnAckProperties}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite}, - reason_codes::ConnAckReasonCode, - Packet, - }; - - #[test] - fn read_write_connack_packet() { - let c = ConnAck { ..Default::default() }; - - let p1 = Packet::ConnAck(c); - let mut buf = bytes::BytesMut::new(); - - p1.write(&mut buf).unwrap(); - - let p2 = Packet::read_from_buffer(&mut buf).unwrap(); - - assert_eq!(p1, p2); - } - - #[test] - fn read_write_connack() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - 0x01, // Connack flags - 0x00, // Reason code, - 0x00, // empty properties - ]; - - buf.extend_from_slice(packet); - let c1 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); - - assert_eq!(ConnAckReasonCode::Success, c1.reason_code); - assert_eq!(ConnAckProperties::default(), c1.connack_properties); - - let mut buf = bytes::BytesMut::new(); - - c1.write(&mut buf).unwrap(); - - let c2 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); - - assert_eq!(c1, c2) - } - - #[test] - fn read_write_connack_properties() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - 56, // ConnAckProperties variable length - 17, // session_expiry_interval - 0xff, 0xff, 37, // retain_available - 0x1, // true - 18, // Assigned Client Id - 0, 11, // 11 bytes "KeanuReeves" without space - b'K', b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', 36, // Max QoS - 2, // QoS 2 Exactly Once - 34, // Topic Alias Max = 255 - 0, 255, 31, // Reason String = 'Houston we have got a problem' - 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', - ]; - - buf.extend_from_slice(packet); - let c1 = ConnAckProperties::read(&mut buf.into()).unwrap(); - - let mut buf = bytes::BytesMut::new(); - - c1.write(&mut buf).unwrap(); - - let _buf_clone = buf.to_vec(); - - let c2 = ConnAckProperties::read(&mut buf.into()).unwrap(); - - assert_eq!(c1, c2); - } -} diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs new file mode 100644 index 0000000..90dc715 --- /dev/null +++ b/mqrstt/src/packets/connack/mod.rs @@ -0,0 +1,274 @@ +mod properties; +pub use properties::ConnAckProperties; + +mod reason_code; +pub use reason_code::ConnAckReasonCode; + +use super::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, +}; +use bytes::{Buf, BufMut}; +use tokio::io::AsyncReadExt; + +/// ConnAck packet is sent by the server in response to a [`crate::packets::Connect`] packet. +/// +/// The ConnAck packet contains the values used by the server related to this connection. +/// +/// For example the requested client identifier can be changed by the server. +/// This is then indicated using the property [`crate::packets::ConnAckProperties::assigned_client_identifier`]. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct ConnAck { + /// 3.2.2.1 Connect Acknowledge Flags + pub connack_flags: ConnAckFlags, + + /// 3.2.2.2 Connect Reason Code + /// Byte 2 in the Variable Header is the Connect Reason Code. + pub reason_code: ConnAckReasonCode, + + /// 3.2.2.3 CONNACK Properties + pub connack_properties: ConnAckProperties, +} + +impl PacketRead for ConnAck { + fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { + if header_len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), header_len)); + } + + let connack_flags = ConnAckFlags::read(&mut buf)?; + let reason_code = ConnAckReasonCode::read(&mut buf)?; + let connack_properties = ConnAckProperties::read(&mut buf)?; + + Ok(Self { + connack_flags, + reason_code, + connack_properties, + }) + } +} + +impl PacketAsyncRead for ConnAck +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; + let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; + let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; + + Ok(( + Self { + connack_flags, + reason_code, + connack_properties, + }, + read_bytes + reason_code_read_bytes + connack_properties_read_bytes, + )) + } +} + +impl PacketWrite for ConnAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.connack_flags.write(buf)?; + self.reason_code.write(buf)?; + self.connack_properties.write(buf)?; + + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for ConnAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + let connack_flags_written = self.connack_flags.async_write(stream).await?; + let reason_code_written = self.reason_code.async_write(stream).await?; + let connack_properties_written = self.connack_properties.async_write(stream).await?; + + Ok(connack_flags_written + reason_code_written + connack_properties_written) + } +} + +impl WireLength for ConnAck { + fn wire_len(&self) -> usize { + 2 + // 1 for connack_flags and 1 for reason_code + self.connack_properties.wire_len().variable_integer_len() + + self.connack_properties.wire_len() + } +} + +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct ConnAckFlags { + pub session_present: bool, +} + +impl MqttAsyncRead for ConnAckFlags +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let byte = stream.read_u8().await?; + Ok(( + Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, + 1, + )) + } +} + +impl MqttRead for ConnAckFlags { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + let byte = buf.get_u8(); + + Ok(Self { + session_present: (byte & 0b00000001) == 0b00000001, + }) + } +} + +impl MqttWrite for ConnAckFlags { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + let byte = self.session_present as u8; + + buf.put_u8(byte); + Ok(()) + } +} + +impl crate::packets::mqtt_trait::MqttAsyncWrite for ConnAckFlags +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.session_present as u8; + + stream.write_u8(byte).await?; + Ok(1) + } +} + +#[cfg(test)] +mod tests { + + use crate::packets::{ + connack::{ConnAck, ConnAckProperties}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + ConnAckReasonCode, Packet, VariableInteger, + }; + + #[test] + fn test_wire_len() { + let mut buf = bytes::BytesMut::new(); + + let connack_properties = ConnAckProperties { + session_expiry_interval: Some(60), // Session expiry interval in seconds + receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently + maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server + retain_available: Some(true), // Whether the server supports retained messages + maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept + assigned_client_identifier: Some(Box::from("client-12345")), // Client identifier assigned by the server + topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server + reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment + user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair + wildcards_available: Some(true), // Whether wildcard subscriptions are available + subscription_ids_available: Some(true), // Whether subscription identifiers are available + shared_subscription_available: Some(true), // Whether shared subscriptions are available + server_keep_alive: Some(120), // Server keep alive time in seconds + response_info: Some(Box::from("Response info")), // Response information + server_reference: Some(Box::from("server-reference")), // Server reference + authentication_method: Some(Box::from("auth-method")), // Authentication method + authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data + }; + + let len = connack_properties.wire_len(); + // determine length of variable integer + let len_of_wire_len = len.write_variable_integer(&mut buf).unwrap(); + // clear buffer before writing actual properties + buf.clear(); + connack_properties.write(&mut buf).unwrap(); + + assert_eq!(len + len_of_wire_len, buf.len()); + } + + #[test] + fn read_write_connack_packet() { + let c = ConnAck { ..Default::default() }; + + let p1 = Packet::ConnAck(c); + let mut buf = bytes::BytesMut::new(); + + p1.write(&mut buf).unwrap(); + + let p2 = Packet::read(&mut buf).unwrap(); + + assert_eq!(p1, p2); + } + + #[test] + fn read_write_connack() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + 0x01, // Connack flags + 0x00, // Reason code, + 0x00, // empty properties + ]; + + buf.extend_from_slice(packet); + let c1 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(ConnAckReasonCode::Success, c1.reason_code); + assert_eq!(ConnAckProperties::default(), c1.connack_properties); + + let mut buf = bytes::BytesMut::new(); + + c1.write(&mut buf).unwrap(); + + let c2 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(c1, c2) + } + + #[test] + fn read_write_connack_properties() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + 56, // ConnAckProperties variable length + 17, // session_expiry_interval + 0xff, 0xff, 37, // retain_available + 0x1, // true + 18, // Assigned Client Id + 0, 11, // 11 bytes "KeanuReeves" without space + b'K', b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', 36, // Max QoS + 2, // QoS 2 Exactly Once + 34, // Topic Alias Max = 255 + 0, 255, 31, // Reason String = 'Houston we have got a problem' + 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', + ]; + + buf.extend_from_slice(packet); + let c1 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + let mut buf = bytes::BytesMut::new(); + + let variable_length = c1.wire_len(); + assert_eq!(variable_length, 56); + + c1.write(&mut buf).unwrap(); + + let _buf_clone = buf.to_vec(); + + let c2 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + assert_eq!(c1, c2); + } +} diff --git a/mqrstt/src/packets/connack/properties.rs b/mqrstt/src/packets/connack/properties.rs new file mode 100644 index 0000000..670b23b --- /dev/null +++ b/mqrstt/src/packets/connack/properties.rs @@ -0,0 +1,255 @@ +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, QoS, VariableInteger, +}; +use bytes::BufMut; + +crate::packets::macros::define_properties!( + /// ConnAck Properties + ConnAckProperties, + SessionExpiryInterval, + ReceiveMaximum, + MaximumQos, + RetainAvailable, + MaximumPacketSize, + AssignedClientIdentifier, + TopicAliasMaximum, + ReasonString, + UserProperty, + WildcardSubscriptionAvailable, + SubscriptionIdentifierAvailable, + SharedSubscriptionAvailable, + ServerKeepAlive, + ResponseInformation, + ServerReference, + AuthenticationMethod, + AuthenticationData +); + +impl MqttRead for ConnAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + let property = PropertyType::read(&mut property_data)?; + match property { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ReceiveMaximum => { + if properties.receive_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); + } + properties.receive_maximum = Some(u16::read(&mut property_data)?); + } + PropertyType::MaximumQos => { + if properties.maximum_qos.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumQos)); + } + properties.maximum_qos = Some(QoS::read(&mut property_data)?); + } + PropertyType::RetainAvailable => { + if properties.retain_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable)); + } + properties.retain_available = Some(bool::read(&mut property_data)?); + } + PropertyType::MaximumPacketSize => { + if properties.maximum_packet_size.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); + } + properties.maximum_packet_size = Some(u32::read(&mut property_data)?); + } + PropertyType::AssignedClientIdentifier => { + if properties.assigned_client_identifier.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); + } + properties.assigned_client_identifier = Some(Box::::read(&mut property_data)?); + } + PropertyType::TopicAliasMaximum => { + if properties.topic_alias_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); + } + properties.topic_alias_maximum = Some(u16::read(&mut property_data)?); + } + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::WildcardSubscriptionAvailable => { + if properties.wildcards_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable)); + } + properties.wildcards_available = Some(bool::read(&mut property_data)?); + } + PropertyType::SubscriptionIdentifierAvailable => { + if properties.subscription_ids_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable)); + } + properties.subscription_ids_available = Some(bool::read(&mut property_data)?); + } + PropertyType::SharedSubscriptionAvailable => { + if properties.shared_subscription_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable)); + } + properties.shared_subscription_available = Some(bool::read(&mut property_data)?); + } + PropertyType::ServerKeepAlive => { + if properties.server_keep_alive.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive)); + } + properties.server_keep_alive = Some(u16::read(&mut property_data)?); + } + PropertyType::ResponseInformation => { + if properties.response_info.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation)); + } + properties.response_info = Some(Box::::read(&mut property_data)?); + } + PropertyType::ServerReference => { + if properties.server_reference.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); + } + properties.server_reference = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::ConnAck)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for ConnAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + let Self { + session_expiry_interval, + receive_maximum, + maximum_qos, + retain_available, + maximum_packet_size, + assigned_client_identifier, + topic_alias_maximum, + reason_string, + user_properties, + wildcards_available, + subscription_ids_available, + shared_subscription_available, + server_keep_alive, + response_info, + server_reference, + authentication_method, + authentication_data, + } = self; + + if let Some(session_expiry_interval) = session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(*session_expiry_interval); + } + if let Some(receive_maximum) = receive_maximum { + PropertyType::ReceiveMaximum.write(buf)?; + buf.put_u16(*receive_maximum); + } + if let Some(maximum_qos) = maximum_qos { + PropertyType::MaximumQos.write(buf)?; + maximum_qos.write(buf)?; + } + if let Some(retain_available) = retain_available { + PropertyType::RetainAvailable.write(buf)?; + retain_available.write(buf)?; + } + if let Some(maximum_packet_size) = maximum_packet_size { + PropertyType::MaximumPacketSize.write(buf)?; + buf.put_u32(*maximum_packet_size); + } + if let Some(client_id) = assigned_client_identifier { + PropertyType::AssignedClientIdentifier.write(buf)?; + client_id.write(buf)?; + } + if let Some(topic_alias_maximum) = topic_alias_maximum { + PropertyType::TopicAliasMaximum.write(buf)?; + buf.put_u16(*topic_alias_maximum); + } + if let Some(reason_string) = reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, val) in user_properties.iter() { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + val.write(buf)?; + } + if let Some(wildcards_available) = wildcards_available { + PropertyType::WildcardSubscriptionAvailable.write(buf)?; + wildcards_available.write(buf)?; + } + if let Some(subscription_ids_available) = subscription_ids_available { + PropertyType::SubscriptionIdentifierAvailable.write(buf)?; + subscription_ids_available.write(buf)?; + } + if let Some(shared_subscription_available) = shared_subscription_available { + PropertyType::SharedSubscriptionAvailable.write(buf)?; + shared_subscription_available.write(buf)?; + } + if let Some(server_keep_alive) = server_keep_alive { + PropertyType::ServerKeepAlive.write(buf)?; + server_keep_alive.write(buf)?; + } + if let Some(response_info) = response_info { + PropertyType::ResponseInformation.write(buf)?; + response_info.write(buf)?; + } + if let Some(server_reference) = server_reference { + PropertyType::ServerReference.write(buf)?; + server_reference.write(buf)?; + } + if let Some(authentication_method) = &authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = authentication_data { + if authentication_method.is_none() { + return Err(SerializeError::AuthDataWithoutAuthMethod); + } + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/connack/reason_code.rs b/mqrstt/src/packets/connack/reason_code.rs new file mode 100644 index 0000000..51edabe --- /dev/null +++ b/mqrstt/src/packets/connack/reason_code.rs @@ -0,0 +1,24 @@ +crate::packets::macros::reason_code!(ConnAckReasonCode, + Success, + UnspecifiedError, + MalformedPacket, + ProtocolError, + ImplementationSpecificError, + UnsupportedProtocolVersion, + ClientIdentifierNotValid, + BadUsernameOrPassword, + NotAuthorized, + ServerUnavailable, + ServerBusy, + Banned, + BadAuthenticationMethod, + TopicNameInvalid, + PacketTooLarge, + QuotaExceeded, + PayloadFormatInvalid, + RetainNotSupported, + QosNotSupported, + UseAnotherServer, + ServerMoved, + ConnectionRateExceeded +); \ No newline at end of file diff --git a/mqrstt/src/packets/connect.rs b/mqrstt/src/packets/connect.rs deleted file mode 100644 index 2ada053..0000000 --- a/mqrstt/src/packets/connect.rs +++ /dev/null @@ -1,892 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, ProtocolVersion, QoS, WireLength, -}; - -/// Variable connect header: -/// -/// -/// ╔═══════════╦═══════════════════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╗ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ ║ Description ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═══════════╩═══════════════════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╣ -/// ║ ║ -/// ║ Protocol Name ║ -/// ╠═══════════╦═══════════════════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 1 ║ Length MSB (0) ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 2 ║ Length LSB (4) ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 3 ║ ‘M’ ║ 0 ║ 1 ║ 0 ║ 0 ║ 1 ║ 1 ║ 0 ║ 1 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 4 ║ ‘Q’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ 0 ║ 1 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 5 ║ ‘T’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 6 ║ ‘T’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╚═══════════╩═══════════════════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╝ -/// -/// Byte 7: -/// The protocol version -/// -/// Byte 8: -/// 3.1.2.3 Connect Flags : -/// ╔═════╦═══════════╦══════════╦═════════════╦═════╦════╦═══════════╦═════════════╦══════════╗ -/// ║ Bit ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═════╬═══════════╬══════════╬═════════════╬═════╩════╬═══════════╬═════════════╬══════════╣ -/// ║ ║ User Name ║ Password ║ Will Retain ║ Will QoS ║ Will Flag ║ Clean Start ║ Reserved ║ -/// ╚═════╩═══════════╩══════════╩═════════════╩══════════╩═══════════╩═════════════╩══════════╝ -/// -/// Byte 9 and 10: -/// The keep alive -/// -/// Byte 11: -/// Length of [`ConnectProperties`] -/// -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Connect { - /// Byte 7 - pub protocol_version: ProtocolVersion, - - /// 3.1.2.4 Clean Start Flag - /// bit 1 - pub clean_start: bool, - - /// 3.1.2.5 Will Flag through option - pub last_will: Option, - - /// 3.1.2.8 User Name Flag - pub username: Option>, - /// 3.1.2.9 Password Flag - pub password: Option>, - - /// 3.1.2.10 Keep Alive - /// Byte 9 and 10 - pub keep_alive: u16, - - /// 3.1.2.11 CONNECT Properties - pub connect_properties: ConnectProperties, - - /// 3.1.3.1 Client Identifier (ClientID) - pub client_id: Box, -} - -impl Default for Connect { - fn default() -> Self { - Self { - protocol_version: ProtocolVersion::V5, - clean_start: true, - last_will: None, - username: None, - password: None, - keep_alive: 60, - connect_properties: ConnectProperties::default(), - client_id: "MQRSTT".into(), - } - } -} - -impl VariableHeaderRead for Connect { - fn read(_: u8, _: usize, mut buf: Bytes) -> Result { - if String::read(&mut buf)? != "MQTT" { - return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); - } - - let protocol_version = ProtocolVersion::read(&mut buf)?; - - let connect_flags = ConnectFlags::read(&mut buf)?; - - let clean_start = connect_flags.clean_start; - let keep_alive = buf.get_u16(); - - let connect_properties = ConnectProperties::read(&mut buf)?; - - let client_id = Box::::read(&mut buf)?; - let mut last_will = None; - if connect_flags.will_flag { - let retain = connect_flags.will_retain; - - last_will = Some(LastWill::read(connect_flags.will_qos, retain, &mut buf)?); - } - - let username = if connect_flags.username { Some(Box::::read(&mut buf)?) } else { None }; - let password = if connect_flags.password { Some(Box::::read(&mut buf)?) } else { None }; - - let connect = Connect { - protocol_version, - clean_start, - last_will, - username, - password, - keep_alive, - connect_properties, - client_id, - }; - - Ok(connect) - } -} - -impl VariableHeaderWrite for Connect { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - "MQTT".write(buf)?; - - self.protocol_version.write(buf)?; - - let mut connect_flags = ConnectFlags { - clean_start: self.clean_start, - ..Default::default() - }; - - if let Some(last_will) = &self.last_will { - connect_flags.will_flag = true; - connect_flags.will_retain = last_will.retain; - connect_flags.will_qos = last_will.qos; - } - connect_flags.username = self.username.is_some(); - connect_flags.password = self.password.is_some(); - - connect_flags.write(buf)?; - - buf.put_u16(self.keep_alive); - - self.connect_properties.write(buf)?; - - self.client_id.write(buf)?; - - if let Some(last_will) = &self.last_will { - last_will.write(buf)?; - } - if let Some(username) = &self.username { - username.write(buf)?; - } - if let Some(password) = &self.password { - password.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for Connect { - fn wire_len(&self) -> usize { - let mut len = "MQTT".wire_len() + 1 + 1 + 2; // protocol version, connect_flags and keep alive - - len += variable_integer_len(self.connect_properties.wire_len()); - len += self.connect_properties.wire_len(); - - if let Some(last_will) = &self.last_will { - len += last_will.wire_len(); - } - if let Some(username) = &self.username { - len += username.wire_len() - } - if let Some(password) = &self.password { - len += password.wire_len() - } - - len += self.client_id.wire_len(); - - len - } -} - -/// ╔═════╦═══════════╦══════════╦═════════════╦═════╦════╦═══════════╦═════════════╦══════════╗ -/// ║ Bit ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═════╬═══════════╬══════════╬═════════════╬═════╩════╬═══════════╬═════════════╬══════════╣ -/// ║ ║ User Name ║ Password ║ Will Retain ║ Will QoS ║ Will Flag ║ Clean Start ║ Reserved ║ -/// ╚═════╩═══════════╩══════════╩═════════════╩══════════╩═══════════╩═════════════╩══════════╝ -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ConnectFlags { - pub clean_start: bool, - pub will_flag: bool, - pub will_qos: QoS, - pub will_retain: bool, - pub password: bool, - pub username: bool, -} - -impl ConnectFlags { - pub fn from_u8(value: u8) -> Result { - Ok(Self { - clean_start: ((value & 0b00000010) >> 1) != 0, - will_flag: ((value & 0b00000100) >> 2) != 0, - will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, - will_retain: ((value & 0b00100000) >> 5) != 0, - password: ((value & 0b01000000) >> 6) != 0, - username: ((value & 0b10000000) >> 7) != 0, - }) - } - - pub fn into_u8(&self) -> Result { - let byte = ((self.clean_start as u8) << 1) - | ((self.will_flag as u8) << 2) - | (self.will_qos.into_u8() << 3) - | ((self.will_retain as u8) << 5) - | ((self.password as u8) << 6) - | ((self.username as u8) << 7); - Ok(byte) - } -} - -impl Default for ConnectFlags { - fn default() -> Self { - Self { - clean_start: false, - will_flag: false, - will_qos: QoS::AtMostOnce, - will_retain: false, - password: false, - username: false, - } - } -} - -impl MqttRead for ConnectFlags { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConnectFlags".to_string(), 0, 1)); - } - - let byte = buf.get_u8(); - - ConnectFlags::from_u8(byte) - } -} - -impl MqttWrite for ConnectFlags { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - buf.put_u8(self.into_u8()?); - Ok(()) - } -} - -/// Connect Properties -/// -/// The wire representation starts with the length of all properties after which -/// the identifiers and their actual value are given -/// -/// 3.1.2.11.1 Property Length -/// The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. -/// Followed by all possible connect properties: -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct ConnectProperties { - /// 3.1.2.11.2 Session Expiry Interval - /// 17 (0x11) Byte Identifier of the Session Expiry Interval - pub session_expiry_interval: Option, - - /// 3.1.2.11.3 Receive Maximum - /// 33 (0x21) Byte, Identifier of the Receive Maximum - pub receive_maximum: Option, - - /// 3.1.2.11.4 Maximum Packet Size - /// 39 (0x27) Byte, Identifier of the Maximum Packet Size - pub maximum_packet_size: Option, - - /// 3.1.2.11.5 Topic Alias Maximum - /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum - pub topic_alias_maximum: Option, - - /// 3.1.2.11.6 Request Response Information - /// 25 (0x19) Byte, Identifier of the Request Response Information - pub request_response_information: Option, - - /// 3.1.2.11.7 Request Problem Information - /// 23 (0x17) Byte, Identifier of the Request Problem Information - pub request_problem_information: Option, - - /// 3.1.2.11.8 User Property - /// 38 (0x26) Byte, Identifier of the User Property - pub user_properties: Vec<(Box, Box)>, - - /// 3.1.2.11.9 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method - pub authentication_method: Option>, - - /// 3.1.2.11.10 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - pub authentication_data: Bytes, -} - -impl MqttWrite for ConnectProperties { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(session_expiry_interval) = self.session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(session_expiry_interval); - } - if let Some(receive_maximum) = self.receive_maximum { - PropertyType::ReceiveMaximum.write(buf)?; - buf.put_u16(receive_maximum); - } - if let Some(maximum_packet_size) = self.maximum_packet_size { - PropertyType::MaximumPacketSize.write(buf)?; - buf.put_u32(maximum_packet_size); - } - if let Some(topic_alias_maximum) = self.topic_alias_maximum { - PropertyType::TopicAliasMaximum.write(buf)?; - buf.put_u16(topic_alias_maximum); - } - if let Some(request_response_information) = self.request_response_information { - PropertyType::RequestResponseInformation.write(buf)?; - buf.put_u8(request_response_information); - } - if let Some(request_problem_information) = self.request_problem_information { - PropertyType::RequestProblemInformation.write(buf)?; - buf.put_u8(request_problem_information); - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - if let Some(authentication_method) = &self.authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if !self.authentication_data.is_empty() { - if self.authentication_method.is_none() { - return Err(SerializeError::AuthDataWithoutAuthMethod); - } - PropertyType::AuthenticationData.write(buf)?; - self.authentication_data.write(buf)?; - } - - Ok(()) - } -} - -impl MqttRead for ConnectProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("ConnectProperties".to_string(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(property_data.get_u32()); - } - PropertyType::ReceiveMaximum => { - if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); - } - properties.receive_maximum = Some(property_data.get_u16()); - } - PropertyType::MaximumPacketSize => { - if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); - } - properties.maximum_packet_size = Some(property_data.get_u32()); - } - PropertyType::TopicAliasMaximum => { - if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); - } - properties.topic_alias_maximum = Some(property_data.get_u16()); - } - PropertyType::RequestResponseInformation => { - if properties.request_response_information.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation)); - } - properties.request_response_information = Some(property_data.get_u8()); - } - PropertyType::RequestProblemInformation => { - if properties.request_problem_information.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation)); - } - properties.request_problem_information = Some(property_data.get_u8()); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Bytes::read(&mut property_data)?; - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), - } - - if property_data.is_empty() { - break; - } - } - - if !properties.authentication_data.is_empty() && properties.authentication_method.is_none() { - return Err(DeserializeError::MalformedPacketWithInfo("Authentication data is not empty while authentication method is".to_string())); - } - - Ok(properties) - } -} - -impl WireLength for ConnectProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.session_expiry_interval.is_some() { - len += 1 + 4; - } - if self.receive_maximum.is_some() { - len += 1 + 2; - } - if self.maximum_packet_size.is_some() { - len += 1 + 4; - } - if self.topic_alias_maximum.is_some() { - len += 1 + 2; - } - if self.request_response_information.is_some() { - len += 2; - } - if self.request_problem_information.is_some() { - len += 2; - } - for (key, value) in &self.user_properties { - len += 1; - len += key.wire_len(); - len += value.wire_len(); - } - if let Some(authentication_method) = &self.authentication_method { - len += 1 + authentication_method.wire_len(); - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - len += 1 + self.authentication_data.wire_len(); - } - - len - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LastWill { - /// 3.1.2.6 Will QoS - pub qos: QoS, - /// 3.1.2.7 Will Retain - pub retain: bool, - - /// 3.1.3.2 Will properties - pub last_will_properties: LastWillProperties, - /// 3.1.3.3 Will Topic - pub topic: Box, - /// 3.1.3.4 Will payload - pub payload: Bytes, -} - -impl LastWill { - pub fn new, P: Into>>(qos: QoS, retain: bool, topic: T, payload: P) -> LastWill { - Self { - qos, - retain, - last_will_properties: LastWillProperties::default(), - topic: topic.as_ref().into(), - payload: Bytes::from(payload.into()), - } - } - pub fn read(qos: QoS, retain: bool, buf: &mut Bytes) -> Result { - let last_will_properties = LastWillProperties::read(buf)?; - let topic = Box::::read(buf)?; - let payload = Bytes::read(buf)?; - - Ok(Self { - qos, - retain, - topic, - payload, - last_will_properties, - }) - } -} - -impl MqttWrite for LastWill { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - self.last_will_properties.write(buf)?; - self.topic.write(buf)?; - self.payload.write(buf)?; - Ok(()) - } -} - -impl WireLength for LastWill { - fn wire_len(&self) -> usize { - let property_len = self.last_will_properties.wire_len(); - - self.topic.wire_len() + self.payload.wire_len() + variable_integer_len(property_len) + property_len - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct LastWillProperties { - /// 3.1.3.2.2 Will Delay Interval - delay_interval: Option, - /// 3.1.3.2.3 Payload Format Indicator - payload_format_indicator: Option, - /// 3.1.3.2.4 Message Expiry Interval - message_expiry_interval: Option, - /// 3.1.3.2.5 Content Type - content_type: Option>, - /// 3.1.3.2.6 Response Topic - response_topic: Option>, - /// 3.1.3.2.7 Correlation Data - correlation_data: Option, - /// 3.1.3.2.8 User Property - user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for LastWillProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("LastWillProperties".to_string(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::WillDelayInterval => { - if properties.delay_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval)); - } - properties.delay_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::PayloadFormatIndicator => { - if properties.payload_format_indicator.is_none() { - return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); - } - properties.payload_format_indicator = Some(u8::read(&mut property_data)?); - } - PropertyType::MessageExpiryInterval => { - if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); - } - properties.message_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ContentType => { - if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); - } - properties.content_type = Some(Box::::read(&mut property_data)?); - } - PropertyType::ResponseTopic => { - if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); - } - properties.response_topic = Some(Box::::read(&mut property_data)?); - } - PropertyType::CorrelationData => { - if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); - } - properties.correlation_data = Some(Bytes::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for LastWillProperties { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(delay_interval) = self.delay_interval { - PropertyType::WillDelayInterval.write(buf)?; - buf.put_u32(delay_interval); - } - if let Some(payload_format_indicator) = self.payload_format_indicator { - PropertyType::PayloadFormatIndicator.write(buf)?; - buf.put_u8(payload_format_indicator); - } - if let Some(message_expiry_interval) = self.message_expiry_interval { - PropertyType::MessageExpiryInterval.write(buf)?; - buf.put_u32(message_expiry_interval); - } - if let Some(content_type) = &self.content_type { - PropertyType::ContentType.write(buf)?; - content_type.write(buf)?; - } - if let Some(response_topic) = &self.response_topic { - PropertyType::ResponseTopic.write(buf)?; - response_topic.write(buf)?; - } - if let Some(correlation_data) = &self.correlation_data { - PropertyType::CorrelationData.write(buf)?; - correlation_data.write(buf)?; - } - if !self.user_properties.is_empty() { - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - } - Ok(()) - } -} - -impl WireLength for LastWillProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.delay_interval.is_some() { - len += 5; - } - if self.payload_format_indicator.is_some() { - len += 2; - } - if self.message_expiry_interval.is_some() { - len += 5; - } - // +1 for the property type - len += self.content_type.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); - len += self.response_topic.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); - len += self.correlation_data.as_ref().map_or_else(|| 0, |b| b.wire_len() + 1); - for (key, value) in &self.user_properties { - len += key.wire_len() + value.wire_len() + 1; - } - - len - } -} - -#[cfg(test)] -mod tests { - use crate::packets::{ - mqtt_traits::{MqttWrite, VariableHeaderRead, VariableHeaderWrite}, - QoS, - }; - - use super::{Connect, ConnectFlags, LastWill}; - - #[test] - fn read_connect() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - // 0x10, - // 39, // packet type, flags and remaining len - 0x00, - 0x04, - b'M', - b'Q', - b'T', - b'T', - 0x05, - 0b1100_1110, // Connect Flags, username, password, will retain=false, will qos=1, last_will, clean_start - 0x00, // Keep alive = 10 sec - 0x0a, - 0x00, // Length of Connect properties - 0x00, // client_id length - 0x04, - b't', // client_id - b'e', - b's', - b't', - 0x00, // Will properties length - 0x00, // length topic - 0x02, - b'/', // Will topic = '/a' - b'a', - 0x00, // Will payload length - 0x0B, - b'h', // Will payload = 'hello world' - b'e', - b'l', - b'l', - b'o', - b' ', - b'w', - b'o', - b'r', - b'l', - b'd', - 0x00, // length username - 0x04, - b'u', // username = 'user' - b's', - b'e', - b'r', - 0x00, // length password - 0x04, - b'p', // Password = 'pass' - b'a', - b's', - b's', - 0xAB, // extra packets in the stream - 0xCD, - 0xEF, - ]; - - buf.extend_from_slice(packet); - let c = Connect::read(0, 0, buf.into()).unwrap(); - - dbg!(c); - } - - #[test] - fn read_and_write_connect() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - // 0x10, - // 39, // packet type, flags and remaining len - 0x00, - 0x04, - b'M', - b'Q', - b'T', - b'T', - 0x05, // variable header - 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, // Keep alive = 10 sec - 0x0a, - 0x00, // Length of Connect properties - 0x00, // client_id length - 0x04, - b't', // client_id - b'e', - b's', - b't', - 0x00, // Will properties length - 0x00, // length topic - 0x02, - b'/', // Will topic = '/a' - b'a', - 0x00, // Will payload length - 0x0B, - b'h', // Will payload = 'hello world' - b'e', - b'l', - b'l', - b'o', - b' ', - b'w', - b'o', - b'r', - b'l', - b'd', - 0x00, // length username - 0x04, - b'u', // username - b's', - b'e', - b'r', - 0x00, // length password - 0x04, - b'p', // payload. password = 'pass' - b'a', - b's', - b's', - ]; - - buf.extend_from_slice(packet); - let c = Connect::read(0, 0, buf.into()).unwrap(); - - let mut write_buf = bytes::BytesMut::new(); - c.write(&mut write_buf).unwrap(); - - assert_eq!(packet.to_vec(), write_buf.to_vec()); - - dbg!(c); - } - - #[test] - fn parsing_last_will() { - let last_will = &[ - 0x00, // Will properties length - 0x00, // length topic - 0x02, b'/', // Will topic = '/a' - b'a', 0x00, // Will payload length - 0x0B, b'h', // Will payload = 'hello world' - b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', - ]; - let mut buf = bytes::Bytes::from_static(last_will); - - assert!(LastWill::read(QoS::AtLeastOnce, false, &mut buf).is_ok()); - } - - #[test] - fn read_and_write_connect2() { - let _packet = [ - 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, - ]; - - let data = [ - 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, - ]; - - let mut buf = bytes::BytesMut::new(); - buf.extend_from_slice(&data); - - let c = Connect::read(0, 0, buf.into()).unwrap(); - - dbg!(c.clone()); - - let mut write_buf = bytes::BytesMut::new(); - c.write(&mut write_buf).unwrap(); - - assert_eq!(data.to_vec(), write_buf.to_vec()); - } - - #[test] - fn parsing_and_writing_last_will() { - let last_will = &[ - 0x00, // Will properties length - 0x00, // length topic - 0x02, b'/', // Will topic = '/a' - b'a', 0x00, // Will payload length - 0x0B, b'h', // Will payload = 'hello world' - b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', - ]; - let mut buf = bytes::Bytes::from_static(last_will); - - let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap(); - - let mut write_buf = bytes::BytesMut::new(); - lw.write(&mut write_buf).unwrap(); - - assert_eq!(last_will.to_vec(), write_buf.to_vec()); - } - - #[test] - fn connect_flag() { - let byte = 0b1100_1110; - let flags = ConnectFlags::from_u8(byte).unwrap(); - assert_eq!(byte, flags.into_u8().unwrap()); - } -} diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs new file mode 100644 index 0000000..129a132 --- /dev/null +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -0,0 +1,104 @@ +use bytes::{Buf, BufMut}; + +use tokio::io::AsyncReadExt; + +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, + QoS, +}; + +/// The connect flags describe some information related the session. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ConnectFlags { + /// Indicates whether to start a new session or continue an existing one. + pub clean_start: bool, + /// Specifies if a Will message is included. + pub will_flag: bool, + /// Defines the Quality of Service level for the Will message. + pub will_qos: QoS, + /// Indicates if the Will message should be retained by the broker. + pub will_retain: bool, + /// Shows if a password is included in the payload. + pub password: bool, + /// Shows if a username is included in the payload. + pub username: bool, +} + +impl ConnectFlags { + pub fn from_u8(value: u8) -> Result { + Ok(Self { + clean_start: ((value & 0b00000010) >> 1) != 0, + will_flag: ((value & 0b00000100) >> 2) != 0, + will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, + will_retain: ((value & 0b00100000) >> 5) != 0, + password: ((value & 0b01000000) >> 6) != 0, + username: ((value & 0b10000000) >> 7) != 0, + }) + } + + pub fn into_u8(&self) -> Result { + let byte = ((self.clean_start as u8) << 1) + | ((self.will_flag as u8) << 2) + | (self.will_qos.into_u8() << 3) + | ((self.will_retain as u8) << 5) + | ((self.password as u8) << 6) + | ((self.username as u8) << 7); + Ok(byte) + } +} + +impl Default for ConnectFlags { + fn default() -> Self { + Self { + clean_start: false, + will_flag: false, + will_qos: QoS::AtMostOnce, + will_retain: false, + password: false, + username: false, + } + } +} + +impl MqttRead for ConnectFlags { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + let byte = buf.get_u8(); + + ConnectFlags::from_u8(byte) + } +} + +impl MqttAsyncRead for ConnectFlags +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; + Ok((ConnectFlags::from_u8(byte)?, 1)) + } +} + +impl MqttWrite for ConnectFlags { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + buf.put_u8(self.into_u8()?); + Ok(()) + } +} + +impl MqttAsyncWrite for ConnectFlags +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.into_u8()?; + stream.write_u8(byte).await?; + + Ok(1) + } +} diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs new file mode 100644 index 0000000..7b3f13a --- /dev/null +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -0,0 +1,157 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::VariableInteger; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttRead, MqttWrite}, + PacketType, PropertyType, WireLength, +}; + +// / +// / The wire representation starts with the length of all properties after which +// / the identifiers and their actual value are given +// / +// / 3.1.2.11.1 Property Length +// / The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. +// / Followed by all possible connect properties: +crate::packets::macros::define_properties!( + /// Connect Properties + ConnectProperties, + SessionExpiryInterval, + ReceiveMaximum, + MaximumPacketSize, + TopicAliasMaximum, + RequestResponseInformation, + RequestProblemInformation, + UserProperty, + AuthenticationMethod, + AuthenticationData +); + +impl MqttWrite for ConnectProperties { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(session_expiry_interval); + } + if let Some(receive_maximum) = self.receive_maximum { + PropertyType::ReceiveMaximum.write(buf)?; + buf.put_u16(receive_maximum); + } + if let Some(maximum_packet_size) = self.maximum_packet_size { + PropertyType::MaximumPacketSize.write(buf)?; + buf.put_u32(maximum_packet_size); + } + if let Some(topic_alias_maximum) = self.topic_alias_maximum { + PropertyType::TopicAliasMaximum.write(buf)?; + buf.put_u16(topic_alias_maximum); + } + if let Some(request_response_information) = self.request_response_information { + PropertyType::RequestResponseInformation.write(buf)?; + buf.put_u8(request_response_information); + } + if let Some(request_problem_information) = self.request_problem_information { + PropertyType::RequestProblemInformation.write(buf)?; + buf.put_u8(request_problem_information); + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + if let Some(authentication_method) = &self.authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = &self.authentication_data { + if self.authentication_method.is_none() { + return Err(SerializeError::AuthDataWithoutAuthMethod); + } + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + Ok(()) + } +} + +impl MqttRead for ConnectProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(property_data.get_u32()); + } + PropertyType::ReceiveMaximum => { + if properties.receive_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); + } + properties.receive_maximum = Some(property_data.get_u16()); + } + PropertyType::MaximumPacketSize => { + if properties.maximum_packet_size.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); + } + properties.maximum_packet_size = Some(property_data.get_u32()); + } + PropertyType::TopicAliasMaximum => { + if properties.topic_alias_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); + } + properties.topic_alias_maximum = Some(property_data.get_u16()); + } + PropertyType::RequestResponseInformation => { + if properties.request_response_information.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation)); + } + properties.request_response_information = Some(property_data.get_u8()); + } + PropertyType::RequestProblemInformation => { + if properties.request_problem_information.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation)); + } + properties.request_problem_information = Some(property_data.get_u8()); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), + } + + if property_data.is_empty() { + break; + } + } + + if properties.authentication_data.as_ref().is_some_and(|data| !data.is_empty()) && properties.authentication_method.is_none() { + return Err(DeserializeError::MalformedPacketWithInfo("Authentication data is not empty while authentication method is".to_string())); + } + + Ok(properties) + } +} diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs new file mode 100644 index 0000000..bc076bc --- /dev/null +++ b/mqrstt/src/packets/connect/last_will.rs @@ -0,0 +1,100 @@ +use bytes::{Bytes, BytesMut}; + +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, + QoS, WireLength, +}; + +use super::{LastWillProperties, VariableInteger}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LastWill { + /// 3.1.2.6 Will QoS + pub qos: QoS, + /// 3.1.2.7 Will Retain + pub retain: bool, + + /// 3.1.3.2 Will properties + pub last_will_properties: LastWillProperties, + /// 3.1.3.3 Will Topic + pub topic: Box, + /// 3.1.3.4 Will payload + pub payload: Vec, +} + +impl LastWill { + pub fn new, P: Into>>(qos: QoS, retain: bool, topic: T, payload: P) -> LastWill { + Self { + qos, + retain, + last_will_properties: LastWillProperties::default(), + topic: topic.as_ref().into(), + payload: payload.into(), + } + } + pub(crate) fn read(qos: QoS, retain: bool, buf: &mut Bytes) -> Result { + let last_will_properties = LastWillProperties::read(buf)?; + let topic = Box::::read(buf)?; + let payload = Vec::::read(buf)?; + + Ok(Self { + qos, + retain, + topic, + payload, + last_will_properties, + }) + } + pub(crate) async fn async_read(qos: QoS, retain: bool, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> + where + S: tokio::io::AsyncRead + Unpin, + { + let (last_will_properties, last_will_properties_read_bytes) = LastWillProperties::async_read(stream).await?; + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (payload, payload_read_bytes) = Vec::::async_read(stream).await?; + + let total_read_bytes = last_will_properties_read_bytes + topic_read_bytes + payload_read_bytes; + + Ok(( + Self { + qos, + retain, + last_will_properties, + topic, + payload, + }, + total_read_bytes, + )) + } +} + +impl MqttWrite for LastWill { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.last_will_properties.write(buf)?; + self.topic.write(buf)?; + self.payload.write(buf)?; + Ok(()) + } +} + +impl MqttAsyncWrite for LastWill +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let properties_written = self.last_will_properties.async_write(stream).await?; + let topic_written = self.topic.async_write(stream).await?; + let payload_written = self.payload.async_write(stream).await?; + + Ok(properties_written + topic_written + payload_written) + } +} + +impl WireLength for LastWill { + fn wire_len(&self) -> usize { + let property_len = self.last_will_properties.wire_len(); + + self.topic.wire_len() + self.payload.wire_len() + property_len.variable_integer_len() + property_len + } +} diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs new file mode 100644 index 0000000..0a580ba --- /dev/null +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -0,0 +1,123 @@ +use bytes::{BufMut, Bytes, BytesMut}; + +use crate::packets::VariableInteger; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttRead, MqttWrite}, + PacketType, PropertyType, WireLength, +}; + +crate::packets::macros::define_properties!( + /// Last Will Properties + LastWillProperties, + WillDelayInterval, + PayloadFormatIndicator, + MessageExpiryInterval, + ContentType, + ResponseTopic, + CorrelationData, + UserProperty +); + +impl MqttRead for LastWillProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::WillDelayInterval => { + if properties.will_delay_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval)); + } + properties.will_delay_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::PayloadFormatIndicator => { + if properties.payload_format_indicator.is_none() { + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); + } + properties.payload_format_indicator = Some(u8::read(&mut property_data)?); + } + PropertyType::MessageExpiryInterval => { + if properties.message_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); + } + properties.message_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ContentType => { + if properties.content_type.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); + } + properties.content_type = Some(Box::::read(&mut property_data)?); + } + PropertyType::ResponseTopic => { + if properties.response_topic.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); + } + properties.response_topic = Some(Box::::read(&mut property_data)?); + } + PropertyType::CorrelationData => { + if properties.correlation_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); + } + properties.correlation_data = Some(Vec::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for LastWillProperties { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(delay_interval) = self.will_delay_interval { + PropertyType::WillDelayInterval.write(buf)?; + buf.put_u32(delay_interval); + } + if let Some(payload_format_indicator) = self.payload_format_indicator { + PropertyType::PayloadFormatIndicator.write(buf)?; + buf.put_u8(payload_format_indicator); + } + if let Some(message_expiry_interval) = self.message_expiry_interval { + PropertyType::MessageExpiryInterval.write(buf)?; + buf.put_u32(message_expiry_interval); + } + if let Some(content_type) = &self.content_type { + PropertyType::ContentType.write(buf)?; + content_type.write(buf)?; + } + if let Some(response_topic) = &self.response_topic { + PropertyType::ResponseTopic.write(buf)?; + response_topic.write(buf)?; + } + if let Some(correlation_data) = &self.correlation_data { + PropertyType::CorrelationData.write(buf)?; + correlation_data.write(buf)?; + } + if !self.user_properties.is_empty() { + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + } + Ok(()) + } +} diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs new file mode 100644 index 0000000..cb5a611 --- /dev/null +++ b/mqrstt/src/packets/connect/mod.rs @@ -0,0 +1,559 @@ +mod last_will_properties; +pub use last_will_properties::LastWillProperties; + +mod connect_flags; +pub use connect_flags::ConnectFlags; + +mod connect_properties; +pub use connect_properties::ConnectProperties; + +mod last_will; +pub use last_will::LastWill; + +use crate::packets::error::ReadError; + +use super::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + ProtocolVersion, VariableInteger, WireLength, +}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::AsyncReadExt; + +/// Connect packet send by the client to the server to initialize a connection. +/// +/// Variable Header +/// - Protocol Name and Version: Identifies the MQTT protocol and version. +/// - Connect Flags: Options like clean start, will flag, will QoS, will retain, password flag, and username flag. +/// - Keep Alive Interval: Maximum time interval between messages. +/// - Properties: Optional settings such as session expiry interval, receive maximum, maximum packet size, and topic alias maximum. +/// +/// Payload +/// - Client Identifier: Unique ID for the client. +/// - Will Message: Optional message sent if the client disconnects unexpectedly. +/// - Username and Password: Optional credentials for authentication. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Connect { + pub protocol_version: ProtocolVersion, + + /// 3.1.2.4 Clean Start Flag + pub clean_start: bool, + /// 3.1.2.5 Will Flag through option + pub last_will: Option, + + /// 3.1.2.8 User Name Flag + pub username: Option>, + /// 3.1.2.9 Password Flag + pub password: Option>, + /// 3.1.2.10 Keep Alive + pub keep_alive: u16, + /// 3.1.2.11 CONNECT Properties + pub connect_properties: ConnectProperties, + + /// 3.1.3.1 Client Identifier (ClientID) + pub client_id: Box, +} + +impl Default for Connect { + fn default() -> Self { + Self { + protocol_version: ProtocolVersion::V5, + clean_start: true, + last_will: None, + username: None, + password: None, + keep_alive: 60, + connect_properties: ConnectProperties::default(), + client_id: "MQRSTT".into(), + } + } +} + +impl PacketRead for Connect { + fn read(_: u8, _: usize, mut buf: Bytes) -> Result { + if String::read(&mut buf)? != "MQTT" { + return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); + } + + let protocol_version = ProtocolVersion::read(&mut buf)?; + + let connect_flags = ConnectFlags::read(&mut buf)?; + + let clean_start = connect_flags.clean_start; + let keep_alive = buf.get_u16(); + + let connect_properties = ConnectProperties::read(&mut buf)?; + + let client_id = Box::::read(&mut buf)?; + let mut last_will = None; + if connect_flags.will_flag { + let retain = connect_flags.will_retain; + + last_will = Some(LastWill::read(connect_flags.will_qos, retain, &mut buf)?); + } + + let username = if connect_flags.username { Some(Box::::read(&mut buf)?) } else { None }; + let password = if connect_flags.password { Some(Box::::read(&mut buf)?) } else { None }; + + let connect = Connect { + protocol_version, + clean_start, + last_will, + username, + password, + keep_alive, + connect_properties, + client_id, + }; + + Ok(connect) + } +} + +impl PacketAsyncRead for Connect +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let mut total_read_bytes = 0; + let expected_protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; + let mut protocol = [0u8; 6]; + stream.read_exact(&mut protocol).await?; + + if protocol != expected_protocol { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacketWithInfo(format!("Protocol not MQTT: {:?}", protocol)))); + } + let (protocol_version, _) = ProtocolVersion::async_read(stream).await?; + let (connect_flags, _) = ConnectFlags::async_read(stream).await?; + // Add "MQTT", protocol version and connect flags read bytes + total_read_bytes += 6 + 1 + 1; + + let clean_start = connect_flags.clean_start; + let keep_alive = stream.read_u16().await?; + // Add keep alive read bytes + total_read_bytes += 2; + + let (connect_properties, prop_read_bytes) = ConnectProperties::async_read(stream).await?; + let (client_id, client_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += prop_read_bytes + client_read_bytes; + + let last_will = if connect_flags.will_flag { + let retain = connect_flags.will_retain; + let (last_will, last_will_read_bytes) = LastWill::async_read(connect_flags.will_qos, retain, stream).await?; + total_read_bytes += last_will_read_bytes; + Some(last_will) + } else { + None + }; + + let (username, username_read_bytes) = if connect_flags.username { + let (username, username_read_bytes) = Box::::async_read(stream).await?; + (Some(username), username_read_bytes) + } else { + (None, 0) + }; + let (password, password_read_bytes) = if connect_flags.password { + let (password, password_read_bytes) = Box::::async_read(stream).await?; + (Some(password), password_read_bytes) + } else { + (None, 0) + }; + + total_read_bytes += username_read_bytes + password_read_bytes; + + let connect = Connect { + protocol_version, + clean_start, + last_will, + username, + password, + keep_alive, + connect_properties, + client_id, + }; + Ok((connect, total_read_bytes)) + } +} + +impl PacketWrite for Connect { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + "MQTT".write(buf)?; + + self.protocol_version.write(buf)?; + + let mut connect_flags = ConnectFlags { + clean_start: self.clean_start, + ..Default::default() + }; + + if let Some(last_will) = &self.last_will { + connect_flags.will_flag = true; + connect_flags.will_retain = last_will.retain; + connect_flags.will_qos = last_will.qos; + } + connect_flags.username = self.username.is_some(); + connect_flags.password = self.password.is_some(); + + connect_flags.write(buf)?; + + buf.put_u16(self.keep_alive); + + self.connect_properties.write(buf)?; + + self.client_id.write(buf)?; + + if let Some(last_will) = &self.last_will { + last_will.write(buf)?; + } + if let Some(username) = &self.username { + username.write(buf)?; + } + if let Some(password) = &self.password { + password.write(buf)?; + } + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Connect +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 6 // protocol header + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + let protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; + // We allready start with 6 as total written bytes thus dont add anymore + stream.write_all(&protocol).await?; + + self.protocol_version.async_write(stream).await?; + + let mut connect_flags = ConnectFlags { + clean_start: self.clean_start, + username: self.username.is_some(), + password: self.password.is_some(), + ..Default::default() + }; + + if let Some(last_will) = &self.last_will { + connect_flags.will_flag = true; + connect_flags.will_retain = last_will.retain; + connect_flags.will_qos = last_will.qos; + } + + connect_flags.async_write(stream).await?; + + stream.write_u16(self.keep_alive).await?; + + total_written_bytes += self.connect_properties.async_write(stream).await?; + + total_written_bytes += self.client_id.async_write(stream).await?; + + if let Some(last_will) = &self.last_will { + total_written_bytes += last_will.async_write(stream).await?; + } + if let Some(username) = &self.username { + total_written_bytes += username.async_write(stream).await?; + } + if let Some(password) = &self.password { + total_written_bytes += password.async_write(stream).await?; + } + + Ok(total_written_bytes) + } + } +} + +impl WireLength for Connect { + fn wire_len(&self) -> usize { + let mut len = "MQTT".wire_len() + 1 + 1 + 2; // protocol version, connect_flags and keep alive + + len += self.connect_properties.wire_len().variable_integer_len(); + len += self.connect_properties.wire_len(); + + if let Some(last_will) = &self.last_will { + len += last_will.wire_len(); + } + if let Some(username) = &self.username { + len += username.wire_len() + } + if let Some(password) = &self.password { + len += password.wire_len() + } + + len += self.client_id.wire_len(); + + len + } +} + +#[cfg(test)] +mod tests { + use crate::packets::{ + mqtt_trait::{MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + QoS, + }; + + use super::{Connect, ConnectFlags, LastWill}; + + #[test] + fn read_connect() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, + 0b1100_1110, // Connect Flags, username, password, will retain=false, will qos=1, last_will, clean_start + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username = 'user' + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // Password = 'pass' + b'a', + b's', + b's', + 0xAB, // extra packets in the stream + 0xCD, + 0xEF, + ]; + + buf.extend_from_slice(packet); + let c = Connect::read(0, 0, buf.into()).unwrap(); + + dbg!(c); + } + + #[test] + fn read_and_write_connect() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, // variable header + 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // payload. password = 'pass' + b'a', + b's', + b's', + ]; + + buf.extend_from_slice(packet); + let c = Connect::read(0, 0, buf.into()).unwrap(); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(packet.to_vec(), write_buf.to_vec()); + + dbg!(c); + } + + #[tokio::test] + async fn read_async_and_write_connect() { + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, // variable header + 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // password = 'pass' + b'a', + b's', + b's', + ]; + + let (c, read_bytes) = Connect::async_read(0, 0, &mut packet.as_slice()).await.unwrap(); + assert_eq!(packet.len(), read_bytes); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(packet.to_vec(), write_buf.to_vec()); + } + + #[test] + fn parsing_last_will() { + let last_will = &[ + 0x00, // Will properties length + 0x00, // length topic + 0x02, b'/', b'a', // Will topic = '/a' + 0x00, 0x0B, // Will payload length + b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', // Will payload = 'hello world' + ]; + let mut buf = bytes::Bytes::from_static(last_will); + assert!(LastWill::read(QoS::AtLeastOnce, false, &mut buf).is_ok()); + } + + #[test] + fn read_and_write_connect2() { + let _packet = [ + 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, + ]; + + let data = [ + 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, + ]; + + let mut buf = bytes::BytesMut::new(); + buf.extend_from_slice(&data); + + let c = Connect::read(0, 0, buf.into()).unwrap(); + + dbg!(c.clone()); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(data.to_vec(), write_buf.to_vec()); + } + + #[test] + fn parsing_and_writing_last_will() { + let last_will = &[ + 0x00, // Will properties length + 0x00, // length topic + 0x02, b'/', // Will topic = '/a' + b'a', 0x00, // Will payload length + 0x0B, b'h', // Will payload = 'hello world' + b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', + ]; + let mut buf = bytes::Bytes::from_static(last_will); + + let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap(); + + let mut write_buf = bytes::BytesMut::new(); + lw.write(&mut write_buf).unwrap(); + + assert_eq!(last_will.to_vec(), write_buf.to_vec()); + } + + #[test] + fn connect_flag() { + let byte = 0b1100_1110; + let flags = ConnectFlags::from_u8(byte).unwrap(); + assert_eq!(byte, flags.into_u8().unwrap()); + } +} diff --git a/mqrstt/src/packets/disconnect.rs b/mqrstt/src/packets/disconnect.rs deleted file mode 100644 index d2307cb..0000000 --- a/mqrstt/src/packets/disconnect.rs +++ /dev/null @@ -1,226 +0,0 @@ -use bytes::BufMut; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::DisconnectReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct Disconnect { - pub reason_code: DisconnectReasonCode, - pub properties: DisconnectProperties, -} - -impl VariableHeaderRead for Disconnect { - fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { - let reason_code; - let properties; - if remaining_length == 0 { - reason_code = DisconnectReasonCode::NormalDisconnection; - properties = DisconnectProperties::default(); - } else { - reason_code = DisconnectReasonCode::read(&mut buf)?; - properties = DisconnectProperties::read(&mut buf)?; - } - - Ok(Self { reason_code, properties }) - } -} -impl VariableHeaderWrite for Disconnect { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - } - Ok(()) - } -} -impl WireLength for Disconnect { - fn wire_len(&self) -> usize { - if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { - let property_len = self.properties.wire_len(); - // reasoncode, length of property length, property length - 1 + variable_integer_len(property_len) + property_len - } else { - 0 - } - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DisconnectProperties { - pub session_expiry_interval: Option, - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, - pub server_reference: Option>, -} - -impl MqttRead for DisconnectProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("DisconnectProperties".to_string(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::from_u8(u8::read(&mut property_data)?)? { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::ServerReference => { - if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); - } - properties.server_reference = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Disconnect)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for DisconnectProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(session_expiry_interval) = self.session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(session_expiry_interval); - } - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, val) in self.user_properties.iter() { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - val.write(buf)?; - } - if let Some(server_refrence) = &self.server_reference { - PropertyType::ServerReference.write(buf)?; - server_refrence.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for DisconnectProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if self.session_expiry_interval.is_some() { - len += 4 + 1; - } - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - len += self.user_properties.iter().fold(0, |acc, (k, v)| acc + k.wire_len() + v.wire_len() + 1); - if let Some(server_refrence) = &self.server_reference { - len += server_refrence.wire_len() + 1; - } - len - } -} - -#[cfg(test)] -mod tests { - use super::*; - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_write_and_read_disconnect() { - let mut buf = bytes::BytesMut::new(); - let packet = Disconnect { - properties: DisconnectProperties { - session_expiry_interval: Some(123), - reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], - server_reference: Some(Box::from("Server reference")), - }, - reason_code: DisconnectReasonCode::NormalDisconnection, - }; - - packet.write(&mut buf).unwrap(); - - let read_packet = Disconnect::read(0, buf.len(), buf.into()).unwrap(); - - assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); - assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); - assert_eq!( - read_packet.properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_packet.properties.server_reference, - Some(Box::from("Server reference")) - ); - } -} - - - #[test] - fn test_write_and_read_disconnect_properties() { - let mut buf = bytes::BytesMut::new(); - let properties = DisconnectProperties { - session_expiry_interval: Some(123), - reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], - server_reference: Some(Box::from("Server reference")), - }; - - properties.write(&mut buf).unwrap(); - - let read_properties = DisconnectProperties::read(&mut buf.into()).unwrap(); - - assert_eq!(read_properties.session_expiry_interval, Some(123)); - assert_eq!(read_properties.reason_string, Some(Box::from("Some reason"))); - assert_eq!( - read_properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_properties.server_reference, - Some(Box::from("Server reference")) - ); - } -} \ No newline at end of file diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs new file mode 100644 index 0000000..e077fb3 --- /dev/null +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -0,0 +1,178 @@ +mod properties; +pub use properties::DisconnectProperties; + +mod reason_code; +pub use reason_code::DisconnectReasonCode; + +use super::{ + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, +}; + +/// The DISCONNECT Packet is the final packet. +/// The client sends this packet to the server to disconnect for example on calling [`crate::MqttClient::disconnect`]. +/// The server can send a disconnect packet to the client to indicate that the connection is being closed. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Disconnect { + pub reason_code: DisconnectReasonCode, + pub properties: DisconnectProperties, +} + +impl PacketAsyncRead for Disconnect +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + if remaining_length == 0 { + Ok(( + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: DisconnectProperties::default(), + }, + 0, + )) + } else { + let (reason_code, reason_code_read_bytes) = DisconnectReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = DisconnectProperties::async_read(stream).await?; + + Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes)) + } + } +} + +impl PacketRead for Disconnect { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { + let reason_code; + let properties; + if remaining_length == 0 { + reason_code = DisconnectReasonCode::NormalDisconnection; + properties = DisconnectProperties::default(); + } else { + reason_code = DisconnectReasonCode::read(&mut buf)?; + properties = DisconnectProperties::read(&mut buf)?; + } + + Ok(Self { reason_code, properties }) + } +} +impl PacketWrite for Disconnect { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + } + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Disconnect +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 0; + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + +impl WireLength for Disconnect { + fn wire_len(&self) -> usize { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + let property_len = self.properties.wire_len(); + // reasoncode, length of property length, property length + 1 + property_len.variable_integer_len() + property_len + } else { + 0 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_and_async_read_disconnect() { + let mut buf = bytes::BytesMut::new(); + let packet = Disconnect { + properties: DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], + server_reference: Some(Box::from("Server reference")), + }, + reason_code: DisconnectReasonCode::NormalDisconnection, + }; + + packet.write(&mut buf).unwrap(); + + let mut stream = &*buf; + + let (read_packet, read_bytes) = Disconnect::async_read(0, buf.len(), &mut stream).await.unwrap(); + + assert_eq!(buf.len(), read_bytes); + assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); + assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_packet.properties.user_properties, + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] + ); + assert_eq!(read_packet.properties.server_reference, Some(Box::from("Server reference"))); + } + + #[test] + fn test_write_and_read_disconnect() { + let mut buf = bytes::BytesMut::new(); + let packet = Disconnect { + properties: DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], + server_reference: Some(Box::from("Server reference")), + }, + reason_code: DisconnectReasonCode::NormalDisconnection, + }; + + packet.write(&mut buf).unwrap(); + + let read_packet = Disconnect::read(0, buf.len(), buf.into()).unwrap(); + + assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); + assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_packet.properties.user_properties, + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] + ); + assert_eq!(read_packet.properties.server_reference, Some(Box::from("Server reference"))); + } + + #[test] + fn test_write_and_read_disconnect_properties() { + let mut buf = bytes::BytesMut::new(); + let properties = DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], + server_reference: Some(Box::from("Server reference")), + }; + + properties.write(&mut buf).unwrap(); + + let read_properties = DisconnectProperties::read(&mut buf.into()).unwrap(); + + assert_eq!(read_properties.session_expiry_interval, Some(123)); + assert_eq!(read_properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_properties.user_properties, + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] + ); + assert_eq!(read_properties.server_reference, Some(Box::from("Server reference"))); + } +} diff --git a/mqrstt/src/packets/disconnect/properties.rs b/mqrstt/src/packets/disconnect/properties.rs new file mode 100644 index 0000000..1c5d2d0 --- /dev/null +++ b/mqrstt/src/packets/disconnect/properties.rs @@ -0,0 +1,80 @@ +use bytes::BufMut; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!(DisconnectProperties, SessionExpiryInterval, ReasonString, UserProperty, ServerReference); + +impl MqttRead for DisconnectProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::try_from(u8::read(&mut property_data)?)? { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::ServerReference => { + if properties.server_reference.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); + } + properties.server_reference = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Disconnect)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for DisconnectProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(session_expiry_interval); + } + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, val) in self.user_properties.iter() { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + val.write(buf)?; + } + if let Some(server_refrence) = &self.server_reference { + PropertyType::ServerReference.write(buf)?; + server_refrence.write(buf)?; + } + Ok(()) + } +} diff --git a/mqrstt/src/packets/disconnect/reason_code.rs b/mqrstt/src/packets/disconnect/reason_code.rs new file mode 100644 index 0000000..e8b1c00 --- /dev/null +++ b/mqrstt/src/packets/disconnect/reason_code.rs @@ -0,0 +1,32 @@ +crate::packets::macros::reason_code!( + DisconnectReasonCode, + NormalDisconnection, + DisconnectWithWillMessage, + UnspecifiedError, + MalformedPacket, + ProtocolError, + ImplementationSpecificError, + NotAuthorized, + ServerBusy, + ServerShuttingDown, + KeepAliveTimeout, + SessionTakenOver, + TopicFilterInvalid, + TopicNameInvalid, + ReceiveMaximumExceeded, + TopicAliasInvalid, + PacketTooLarge, + MessageRateTooHigh, + QuotaExceeded, + AdministrativeAction, + PayloadFormatInvalid, + RetainNotSupported, + QosNotSupported, + UseAnotherServer, + ServerMoved, + SharedSubscriptionsNotSupported, + ConnectionRateExceeded, + MaximumConnectTime, + SubscriptionIdentifiersNotSupported, + WildcardSubscriptionsNotSupported +); diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index c54689c..68ede44 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -4,6 +4,22 @@ use thiserror::Error; use super::{PacketType, PropertyType}; +#[derive(Error, Debug)] +pub enum WriteError { + #[error("{0}")] + SerializeError(#[from] SerializeError), + #[error("{0}")] + IoError(#[from] std::io::Error), +} + +#[derive(Error, Debug)] +pub enum ReadError { + #[error("{0}")] + DeserializeError(#[from] DeserializeError), + #[error("{0}")] + IoError(#[from] std::io::Error), +} + #[derive(Error, Clone, Debug)] pub enum DeserializeError { #[error("Malformed packet: {0}")] @@ -22,11 +38,17 @@ pub enum DeserializeError { UnknownProtocolVersion, #[error("There is insufficient for {0} data ({1}) to take {2} bytes")] - InsufficientData(String, usize, usize), + InsufficientData(&'static str, usize, usize), #[error("There is insufficient to read the protocol version.")] InsufficientDataForProtocolVersion, + #[error("Read more data for the packet than indicated length")] + ReadTooMuchData(&'static str, usize, usize), + + #[error("While reading a packet {read} bytes was read, but the packet indicated a remaining length of {remaining_length} bytes")] + RemainingDataError { read: usize, remaining_length: usize }, + #[error("Reason code {0} is not allowed for packet type {1:?}")] UnexpectedReasonCode(u8, PacketType), diff --git a/mqrstt/src/packets/macros/mod.rs b/mqrstt/src/packets/macros/mod.rs new file mode 100644 index 0000000..22e829f --- /dev/null +++ b/mqrstt/src/packets/macros/mod.rs @@ -0,0 +1,5 @@ +mod properties_macros; +mod reason_code_macros; + +pub(crate) use properties_macros::*; +pub(crate) use reason_code_macros::*; diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs new file mode 100644 index 0000000..759c4b2 --- /dev/null +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -0,0 +1,885 @@ +macro_rules! define_properties { + ($(#[$attr:meta])* $name:ident, $($prop_variant:ident),*) => { + $crate::packets::macros::properties_struct!(@ + $(#[$attr])* + $name { $($prop_variant,)* } -> () + ); + + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncRead + Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { + let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; + if len == 0 { + return Ok((Self::default(), length_variable_integer)); + } + + let mut properties = $name::default(); + + let mut read_property_bytes = 0; + loop { + let (prop, read_bytes) = crate::packets::PropertyType::async_read(stream).await?; + read_property_bytes += read_bytes; + match prop { + $( + $crate::packets::macros::properties_read_match_branch_name!($prop_variant) => $crate::packets::macros::properties_read_match_branch_body!(stream, properties, read_property_bytes, PropertyType::$prop_variant), + )* + e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), + } + if read_property_bytes == len { + break; + } + } + + Ok((properties, length_variable_integer + read_property_bytes)) + } + } + + impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { + async fn async_write(&self, stream: &mut S) -> Result { + let mut bytes_written = 0; + bytes_written += $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; + $( + $crate::packets::macros::properties_write!(self, bytes_written, stream, PropertyType::$prop_variant); + )* + + Ok(bytes_written) + } + } + + impl $crate::packets::mqtt_trait::WireLength for $name { + fn wire_len(&self) -> usize { + let mut len: usize = 0; + $( + $crate::packets::macros::properties_wire_length!(self, len , PropertyType::$prop_variant); + )* + len + } + } + }; +} + +macro_rules! properties_struct { + ( @ $(#[$attr:meta])* $name:ident { } -> ($($result:tt)*) ) => ( + // $(#[$attr])* + #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] + pub struct $name { + $($result)* + } + ); + ( @ $(#[$attr:meta])* $name:ident { PayloadFormatIndicator, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.2 Payload Format Indicator + /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. + pub payload_format_indicator: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { MessageExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.3 Message Expiry Interval + /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. + pub message_expiry_interval: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ContentType, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.9 Content Type + /// 3 (0x03) Identifier of the Content Type + pub content_type: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ResponseTopic, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.5 Response Topic + /// 8 (0x08) Byte, Identifier of the Response Topic. + pub response_topic: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { CorrelationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.6 Correlation Data + /// 9 (0x09) Byte, Identifier of the Correlation Data. + pub correlation_data: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ListSubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.8 Subscription Identifier + /// 11 (0x0B), Identifier of the Subscription Identifier. + /// Multiple Subscription Identifiers used in the Publish packet. + pub subscription_identifiers: Vec, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { SubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.8 Subscription Identifier + /// 11 (0x0B), Identifier of the Subscription Identifier. + pub subscription_identifier: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { SessionExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.2 Session Expiry Interval + /// 17 (0x11) Byte Identifier of the Session Expiry Interval + pub session_expiry_interval: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.7 Assigned Client Identifier + /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. + pub assigned_client_identifier: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.14 Server Keep Alive + /// 19 (0x13) Byte, Identifier of the Server Keep Alive + pub server_keep_alive: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.17 Authentication Method + /// 21 (0x15) Byte, Identifier of the Authentication Method + pub authentication_method: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.18 Authentication Data + /// 22 (0x16) Byte, Identifier of the Authentication Data + pub authentication_data: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.1.2.11.7 Request Problem Information + /// 23 (0x17) Byte, Identifier of the Request Problem Information + pub request_problem_information: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.1.3.2.2 Request Problem Information + /// 24 (0x18) Byte, Identifier of the Will Delay Interval. + pub will_delay_interval: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.1.2.11.6 Request Response Information + /// 25 (0x19) Byte, Identifier of the Request Response Information + pub request_response_information: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.15 Response Information + /// 26 (0x1A) Byte, Identifier of the Response Information. + pub response_info: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.16 Server Reference + /// 28 (0x1C) Byte, Identifier of the Server Reference + pub server_reference: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ReasonString, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.9 Reason String + /// 31 (0x1F) Byte Identifier of the Reason String. + pub reason_string: Option>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.3 Receive Maximum + /// 33 (0x21) Byte, Identifier of the Receive Maximum + pub receive_maximum: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.8 Topic Alias Maximum + /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. + pub topic_alias_maximum: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.4 Topic Alias + /// 35 (0x23) Byte, Identifier of the Topic Alias. + pub topic_alias: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.4 Maximum QoS + /// 36 (0x24) Byte, Identifier of the Maximum QoS. + pub maximum_qos: Option<$crate::packets::QoS>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.5 Retain Available + /// 37 (0x25) Byte, Identifier of Retain Available. + pub retain_available: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.10 User Property + /// 38 (0x26) Byte, Identifier of User Property. + pub user_properties: Vec<(Box, Box)>, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.6 Maximum Packet Size + /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. + pub maximum_packet_size: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.11 Wildcard Subscription Available + /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. + pub wildcards_available: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.12 Subscription Identifiers Available + /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. + pub subscription_ids_available: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.13 Shared Subscription Available + /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. + pub shared_subscription_available: Option, + )); + ); + ( @ $(#[$attr:meta])* $name:ident { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + ); +} + +macro_rules! properties_read_match_branch_body { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::PayloadFormatIndicator) => {{ + if $properties.payload_format_indicator.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::PayloadFormatIndicator, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.payload_format_indicator = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MessageExpiryInterval) => {{ + if $properties.message_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::MessageExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.message_expiry_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ContentType) => {{ + if $properties.content_type.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ContentType))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.content_type = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseTopic) => {{ + if $properties.response_topic.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_topic = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::CorrelationData) => {{ + if $properties.correlation_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::CorrelationData))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.correlation_data = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => {{ + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifier = Some(prop_body as u32); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ListSubscriptionIdentifier) => {{ + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifiers.push(prop_body as u32); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => {{ + if $properties.session_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SessionExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.session_expiry_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => {{ + if $properties.assigned_client_identifier.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AssignedClientIdentifier, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.assigned_client_identifier = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => {{ + if $properties.server_keep_alive.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_keep_alive = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => {{ + if $properties.authentication_method.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AuthenticationMethod, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_method = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => {{ + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AuthenticationData, + ))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_data = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => {{ + if $properties.request_response_information.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::RequestResponseInformation, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_response_information = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestProblemInformation) => {{ + if $properties.request_problem_information.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::RequestProblemInformation, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_problem_information = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WillDelayInterval) => {{ + if $properties.will_delay_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::WillDelayInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.will_delay_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => {{ + if $properties.response_info.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::ResponseInformation, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_info = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => {{ + if $properties.server_reference.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_reference = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReasonString) => {{ + if $properties.reason_string.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.reason_string = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => {{ + if $properties.receive_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.receive_maximum = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => {{ + if $properties.topic_alias_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::TopicAliasMaximum, + ))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.topic_alias_maximum = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAlias) => {{ + if $properties.topic_alias.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::MessageExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.topic_alias = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => {{ + if $properties.maximum_qos.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); + } + let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_qos = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => {{ + if $properties.retain_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.retain_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => {{ + let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + + $properties.user_properties.push((prop_body_key, prop_body_value)) + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => {{ + if $properties.maximum_packet_size.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_packet_size = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => {{ + if $properties.wildcards_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::WildcardSubscriptionAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.wildcards_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => {{ + if $properties.subscription_ids_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SubscriptionIdentifierAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_ids_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => {{ + if $properties.shared_subscription_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SharedSubscriptionAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.shared_subscription_available = Some(prop_body); + }}; +} + +macro_rules! properties_read_match_branch_name { + (ListSubscriptionIdentifier) => { + PropertyType::SubscriptionIdentifier + }; + ($name:ident) => { + PropertyType::$name + }; +} + +macro_rules! properties_write { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { + if let Some(payload_format_indicator) = &($self.payload_format_indicator) { + $bytes_written += PropertyType::PayloadFormatIndicator.async_write($stream).await?; + $bytes_written += payload_format_indicator.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { + if let Some(message_expiry_interval) = &($self.message_expiry_interval) { + $bytes_written += PropertyType::MessageExpiryInterval.async_write($stream).await?; + $bytes_written += message_expiry_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ContentType) => { + if let Some(content_type) = &($self.content_type) { + $bytes_written += PropertyType::ContentType.async_write($stream).await?; + $bytes_written += content_type.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseTopic) => { + if let Some(response_topic) = &($self.response_topic) { + $bytes_written += PropertyType::ResponseTopic.async_write($stream).await?; + $bytes_written += response_topic.as_ref().async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::CorrelationData) => { + if let Some(correlation_data) = &($self.correlation_data) { + $bytes_written += PropertyType::CorrelationData.async_write($stream).await?; + $bytes_written += correlation_data.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { + if let Some(sub_id) = &($self.subscription_identifier) { + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { + for sub_id in &($self.subscription_identifiers) { + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { + if let Some(session_expiry_interval) = &($self.session_expiry_interval) { + $bytes_written += PropertyType::SessionExpiryInterval.async_write($stream).await?; + $bytes_written += session_expiry_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerKeepAlive) => { + if let Some(server_keep_alive) = &($self.server_keep_alive) { + $bytes_written += PropertyType::ServerKeepAlive.async_write($stream).await?; + $bytes_written += server_keep_alive.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationMethod) => { + if let Some(authentication_method) = &($self.authentication_method) { + $bytes_written += PropertyType::AuthenticationMethod.async_write($stream).await?; + $bytes_written += authentication_method.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationData) => { + if let Some(authentication_data) = &($self.authentication_data) { + if !authentication_data.is_empty() && ($self.authentication_method).is_some() { + $bytes_written += PropertyType::AuthenticationData.async_write($stream).await?; + $bytes_written += authentication_data.async_write($stream).await?; + } + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestProblemInformation) => { + if let Some(request_problem_information) = &($self.request_problem_information) { + $bytes_written += PropertyType::RequestProblemInformation.async_write($stream).await?; + $bytes_written += request_problem_information.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WillDelayInterval) => { + if let Some(delay_interval) = &($self.will_delay_interval) { + $bytes_written += PropertyType::WillDelayInterval.async_write($stream).await?; + $bytes_written += delay_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestResponseInformation) => { + if let Some(request_response_information) = &($self.request_response_information) { + $bytes_written += PropertyType::RequestResponseInformation.async_write($stream).await?; + $bytes_written += request_response_information.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseInformation) => { + if let Some(response_info) = &($self.response_info) { + $bytes_written += PropertyType::ResponseInformation.async_write($stream).await?; + $bytes_written += response_info.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerReference) => { + if let Some(server_refrence) = &($self.server_reference) { + $bytes_written += PropertyType::ServerReference.async_write($stream).await?; + server_refrence.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReasonString) => { + if let Some(reason_string) = &($self.reason_string) { + $bytes_written += PropertyType::ReasonString.async_write($stream).await?; + $bytes_written += reason_string.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReceiveMaximum) => { + if let Some(receive_maximum) = &($self.receive_maximum) { + $bytes_written += PropertyType::ReceiveMaximum.async_write($stream).await?; + $bytes_written += receive_maximum.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { + if let Some(topic_alias_maximum) = &($self.topic_alias_maximum) { + $bytes_written += PropertyType::TopicAliasMaximum.async_write($stream).await?; + $bytes_written += topic_alias_maximum.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAlias) => { + if let Some(topic_alias) = &($self.topic_alias) { + $bytes_written += PropertyType::TopicAlias.async_write($stream).await?; + $bytes_written += topic_alias.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumQos) => { + if let Some(maximum_qos) = &($self.maximum_qos) { + $bytes_written += PropertyType::MaximumQos.async_write($stream).await?; + $bytes_written += maximum_qos.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RetainAvailable) => { + if let Some(retain_available) = &($self.retain_available) { + $bytes_written += PropertyType::RetainAvailable.async_write($stream).await?; + $bytes_written += retain_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::UserProperty) => { + for (key, value) in &($self.user_properties) { + $bytes_written += PropertyType::UserProperty.async_write($stream).await?; + $bytes_written += key.async_write($stream).await?; + $bytes_written += value.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumPacketSize) => { + if let Some(maximum_packet_size) = &($self.maximum_packet_size) { + $bytes_written += PropertyType::MaximumPacketSize.async_write($stream).await?; + $bytes_written += maximum_packet_size.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { + if let Some(wildcards_available) = &($self.wildcards_available) { + $bytes_written += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; + $bytes_written += wildcards_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { + if let Some(subscription_ids_available) = &($self.subscription_ids_available) { + $bytes_written += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; + $bytes_written += subscription_ids_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { + if let Some(shared_subscription_available) = &($self.shared_subscription_available) { + $bytes_written += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; + $bytes_written += shared_subscription_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_written:ident, $stream:ident, $unknown:ident) => { + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + }; +} + +macro_rules! properties_wire_length { + ($self:ident, $len:ident, PropertyType::PayloadFormatIndicator) => { + if $self.payload_format_indicator.is_some() { + $len += 2; + } + }; + ($self:ident, $len:ident, PropertyType::MessageExpiryInterval) => { + if $self.message_expiry_interval.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::ContentType) => { + if let Some(content_type) = &($self.content_type) { + $len += 1 + content_type.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ResponseTopic) => { + if let Some(response_topic) = &($self.response_topic) { + $len += 1 + response_topic.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::CorrelationData) => { + if let Some(correlation_data) = &($self.correlation_data) { + $len += 1 + correlation_data.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::SubscriptionIdentifier) => { + if let Some(sub_id) = &($self.subscription_identifier) { + $len += 1 + crate::packets::primitive::VariableInteger::variable_integer_len(sub_id); + } + }; + ($self:ident, $len:ident, PropertyType::ListSubscriptionIdentifier) => { + for sub_id in &($self.subscription_identifiers) { + $len += 1 + crate::packets::primitive::VariableInteger::variable_integer_len(sub_id); + } + }; + ($self:ident, $len:ident, PropertyType::SessionExpiryInterval) => { + if $self.session_expiry_interval.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::AssignedClientIdentifier) => { + if let Some(client_id) = $self.assigned_client_identifier.as_ref() { + $len += 1 + client_id.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ServerKeepAlive) => { + if $self.server_keep_alive.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::AuthenticationMethod) => { + if let Some(authentication_method) = &($self.authentication_method) { + $len += 1 + authentication_method.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::AuthenticationData) => { + if let Some(authentication_data) = &($self).authentication_data { + if !authentication_data.is_empty() && $self.authentication_method.is_some() { + $len += 1 + authentication_data.wire_len(); + } + } + }; + ($self:ident, $len:ident, PropertyType::RequestProblemInformation) => { + if $self.request_problem_information.is_some() { + $len += 2; + } + }; + ($self:ident, $len:ident, PropertyType::WillDelayInterval) => { + if $self.will_delay_interval.is_some() { + $len += 5; + } + }; + ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { + if $self.request_response_information.is_some() { + $len += 2; + } + }; + ($self:ident, $len:ident, PropertyType::ResponseInformation) => { + if let Some(response_info) = &($self.response_info) { + $len += 1 + response_info.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ServerReference) => { + if let Some(server_reference) = &($self.server_reference) { + $len += 1 + server_reference.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ReasonString) => { + if let Some(reason_string) = &($self.reason_string) { + $len += 1 + reason_string.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ReceiveMaximum) => { + if $self.receive_maximum.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::TopicAliasMaximum) => { + if $self.topic_alias_maximum.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::TopicAlias) => { + if $self.topic_alias.is_some() { + $len += 3; + } + }; + ($self:ident, $len:ident, PropertyType::MaximumQos) => { + if $self.maximum_qos.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::RetainAvailable) => { + if $self.retain_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::UserProperty) => { + for (key, value) in &($self.user_properties) { + $len += 1; + $len += key.wire_len(); + $len += value.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::MaximumPacketSize) => { + if $self.maximum_packet_size.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::WildcardSubscriptionAvailable) => { + if $self.wildcards_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::SubscriptionIdentifierAvailable) => { + if $self.subscription_ids_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::SharedSubscriptionAvailable) => { + if $self.shared_subscription_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, $unknown:ident) => { + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + }; +} + +pub(crate) use define_properties; +pub(crate) use properties_read_match_branch_body; +pub(crate) use properties_read_match_branch_name; +pub(crate) use properties_struct; +pub(crate) use properties_wire_length; +pub(crate) use properties_write; diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs new file mode 100644 index 0000000..68ec2e4 --- /dev/null +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -0,0 +1,632 @@ +macro_rules! reason_code { + ($name:ident, $($code:ident),*) => { + use tokio::io::AsyncReadExt; + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub enum $name { + #[default] + $($code),* + } + + impl $name { + pub(crate) fn from_u8(val: u8) -> Result { + $crate::packets::macros::reason_code_match!(@ $name, val, { + $($code,)* + } -> ()) + } + + pub(crate) fn to_u8(self) -> u8 { + $crate::packets::macros::reason_code_match_write!(@ $name, self, { + $($code,)* + } -> ()) + } + } + + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncRead + std::marker::Unpin{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { + let input = stream.read_u8().await?; + let res = Self::from_u8(input)?; + Ok((res, 1)) + } + } + + impl $crate::packets::mqtt_trait::MqttRead for $name { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err($crate::packets::error::DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + use bytes::Buf; + let input = buf.get_u8(); + Self::from_u8(input) + } + } + + impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + std::marker::Unpin{ + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let val = self.to_u8(); + stream.write_u8(val).await?; + Ok(1) + } + } + + impl $crate::packets::mqtt_trait::MqttWrite for $name { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), $crate::packets::error::SerializeError> { + use bytes::BufMut; + let val = self.to_u8(); + buf.put_u8(val); + Ok(()) + } + } + + }; +} + +macro_rules! reason_code_match { + ( @ $name:ident, $input:ident, { } -> ($($result:tt)*) ) => ( + match $input { + $($result)* + t => Err($crate::packets::error::DeserializeError::UnknownProperty(t)), + } + ); + ( @ $name:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::Success), + )) + ); + ( @ $name:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::NormalDisconnection), + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::GrantedQoS0), + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x01 => Ok($name::GrantedQoS1), + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x02 => Ok($name::GrantedQoS2), + )) + ); + ( @ $name:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x04 => Ok($name::DisconnectWithWillMessage), + )) + ); + ( @ $name:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x10 => Ok($name::NoMatchingSubscribers), + )) + ); + ( @ $name:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x11 => Ok($name::NoSubscriptionExisted), + )) + ); + ( @ $name:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x18 => Ok($name::ContinueAuthentication), + )) + ); + ( @ $name:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x19 => Ok($name::ReAuthenticate), + )) + ); + ( @ $name:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x80 => Ok($name::UnspecifiedError), + )) + ); + ( @ $name:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x81 => Ok($name::MalformedPacket), + )) + ); + ( @ $name:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x82 => Ok($name::ProtocolError), + )) + ); + ( @ $name:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x83 => Ok($name::ImplementationSpecificError), + )) + ); + ( @ $name:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x84 => Ok($name::UnsupportedProtocolVersion), + )) + ); + ( @ $name:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x85 => Ok($name::ClientIdentifierNotValid), + )) + ); + ( @ $name:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x86 => Ok($name::BadUsernameOrPassword), + )) + ); + ( @ $name:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x87 => Ok($name::NotAuthorized), + )) + ); + ( @ $name:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x88 => Ok($name::ServerUnavailable), + )) + ); + ( @ $name:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x89 => Ok($name::ServerBusy), + )) + ); + ( @ $name:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8A => Ok($name::Banned), + )) + ); + ( @ $name:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8B => Ok($name::ServerShuttingDown), + )) + ); + ( @ $name:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8C => Ok($name::BadAuthenticationMethod), + )) + ); + ( @ $name:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8D => Ok($name::KeepAliveTimeout), + )) + ); + ( @ $name:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8E => Ok($name::SessionTakenOver), + )) + ); + ( @ $name:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8F => Ok($name::TopicFilterInvalid), + )) + ); + ( @ $name:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x90 => Ok($name::TopicNameInvalid), + )) + ); + ( @ $name:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x91 => Ok($name::PacketIdentifierInUse), + )) + ); + ( @ $name:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x92 => Ok($name::PacketIdentifierNotFound), + )) + ); + ( @ $name:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x93 => Ok($name::ReceiveMaximumExceeded), + )) + ); + ( @ $name:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x94 => Ok($name::TopicAliasInvalid), + )) + ); + ( @ $name:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x95 => Ok($name::PacketTooLarge), + )) + ); + ( @ $name:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x96 => Ok($name::MessageRateTooHigh), + )) + ); + ( @ $name:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x97 => Ok($name::QuotaExceeded), + )) + ); + ( @ $name:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x98 => Ok($name::AdministrativeAction), + )) + ); + ( @ $name:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x99 => Ok($name::PayloadFormatInvalid), + )) + ); + ( @ $name:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9A => Ok($name::RetainNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9B => Ok($name::QosNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9C => Ok($name::UseAnotherServer), + )) + ); + ( @ $name:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9D => Ok($name::ServerMoved), + )) + ); + ( @ $name:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9E => Ok($name::SharedSubscriptionsNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9F => Ok($name::ConnectionRateExceeded), + )) + ); + ( @ $name:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA0 => Ok($name::MaximumConnectTime), + )) + ); + ( @ $name:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA1 => Ok($name::SubscriptionIdentifiersNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA2 => Ok($name::WildcardSubscriptionsNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown reason_code: ", stringify!($unknown))) + ); +} + +macro_rules! reason_code_match_write{ + ( @ $name:ident, $input:ident, { } -> ($($result:tt)*) ) => ( + match $input { + $($result)* + } + ); + ( @ $name:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::Success => 0x00, + )) + ); + ( @ $name:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::NormalDisconnection => 0x00, + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS0 => 0x00, + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS1 => 0x01, + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS2 => 0x02, + )) + ); + ( @ $name:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::DisconnectWithWillMessage => 0x04, + )) + ); + ( @ $name:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::NoMatchingSubscribers => 0x10, + )) + ); + ( @ $name:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::NoSubscriptionExisted => 0x11, + )) + ); + ( @ $name:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ContinueAuthentication => 0x18, + )) + ); + ( @ $name:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ReAuthenticate => 0x19, + )) + ); + ( @ $name:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::UnspecifiedError => 0x80, + )) + ); + ( @ $name:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::MalformedPacket => 0x81, + )) + ); + ( @ $name:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ProtocolError => 0x82, + )) + ); + ( @ $name:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ImplementationSpecificError => 0x83, + )) + ); + ( @ $name:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::UnsupportedProtocolVersion => 0x84, + )) + ); + ( @ $name:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ClientIdentifierNotValid => 0x85, + )) + ); + ( @ $name:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::BadUsernameOrPassword => 0x86, + )) + ); + ( @ $name:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::NotAuthorized => 0x87, + )) + ); + ( @ $name:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ServerUnavailable => 0x88, + )) + ); + ( @ $name:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ServerBusy => 0x89, + )) + ); + ( @ $name:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::Banned => 0x8A, + )) + ); + ( @ $name:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ServerShuttingDown => 0x8B , + )) + ); + ( @ $name:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::BadAuthenticationMethod => 0x8C, + )) + ); + + ( @ $name:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::KeepAliveTimeout => 0x8D, + )) + ); + ( @ $name:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::SessionTakenOver => 0x8E, + )) + ); + ( @ $name:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::TopicFilterInvalid => 0x8F, + )) + ); + ( @ $name:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::TopicNameInvalid => 0x90, + )) + ); + ( @ $name:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::PacketIdentifierInUse => 0x91, + )) + ); + ( @ $name:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::PacketIdentifierNotFound => 0x92, + + )) + ); + ( @ $name:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ReceiveMaximumExceeded => 0x93, + )) + ); + ( @ $name:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::TopicAliasInvalid => 0x94, + )) + ); + ( @ $name:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::PacketTooLarge => 0x95, + )) + ); + ( @ $name:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::MessageRateTooHigh => 0x96, + )) + ); + ( @ $name:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::QuotaExceeded => 0x97, + )) + ); + ( @ $name:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::AdministrativeAction => 0x98, + )) + ); + ( @ $name:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::PayloadFormatInvalid => 0x99, + )) + ); + ( @ $name:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::RetainNotSupported => 0x9A, + )) + ); + ( @ $name:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::QosNotSupported => 0x9B, + )) + ); + ( @ $name:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::UseAnotherServer => 0x9C, + )) + ); + ( @ $name:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ServerMoved => 0x9D, + )) + ); + ( @ $name:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::SharedSubscriptionsNotSupported => 0x9E, + )) + ); + ( @ $name:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::ConnectionRateExceeded => 0x9F, + )) + ); + ( @ $name:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::MaximumConnectTime => 0xA0, + )) + ); + ( @ $name:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::SubscriptionIdentifiersNotSupported => 0xA1, + )) + ); + ( @ $name:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( + $($result)* + $name::WildcardSubscriptionsNotSupported => 0xA2, + )) + ); + + ( @ $name:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown reason_code: ", stringify!($unknown))) + ); +} + +pub(crate) use reason_code; +pub(crate) use reason_code_match; +pub(crate) use reason_code_match_write; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 8675d1b..5921c21 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -1,6 +1,7 @@ pub mod error; -pub mod mqtt_traits; -pub mod reason_codes; +pub(crate) mod mqtt_trait; + +mod macros; mod auth; mod connack; @@ -16,6 +17,11 @@ mod subscribe; mod unsuback; mod unsubscribe; +mod primitive; +use error::{ReadError, WriteError}; +use mqtt_trait::{PacketAsyncRead, PacketAsyncWrite}; +pub use primitive::*; + pub use auth::*; pub use connack::*; pub use connect::*; @@ -30,519 +36,13 @@ pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use core::slice::Iter; +use bytes::{BufMut, Bytes, BytesMut}; use std::fmt::Display; -use self::error::{DeserializeError, ReadBytes, SerializeError}; -use self::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; - -/// Protocol version -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtocolVersion { - V5, -} - -impl MqttWrite for ProtocolVersion { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u8(5u8); - Ok(()) - } -} - -impl MqttRead for ProtocolVersion { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientDataForProtocolVersion); - } - - match buf.get_u8() { - 3 => Err(DeserializeError::UnsupportedProtocolVersion), - 4 => Err(DeserializeError::UnsupportedProtocolVersion), - 5 => Ok(ProtocolVersion::V5), - _ => Err(DeserializeError::UnknownProtocolVersion), - } - } -} - -/// Quality of service -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum QoS { - #[default] - AtMostOnce = 0, - AtLeastOnce = 1, - ExactlyOnce = 2, -} -impl QoS { - pub fn from_u8(value: u8) -> Result { - match value { - 0 => Ok(QoS::AtMostOnce), - 1 => Ok(QoS::AtLeastOnce), - 2 => Ok(QoS::ExactlyOnce), - _ => Err(DeserializeError::UnknownQoS(value)), - } - } - pub fn into_u8(self) -> u8 { - match self { - QoS::AtMostOnce => 0, - QoS::AtLeastOnce => 1, - QoS::ExactlyOnce => 2, - } - } -} - -impl MqttRead for QoS { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("QoS".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0 => Ok(QoS::AtMostOnce), - 1 => Ok(QoS::AtLeastOnce), - 2 => Ok(QoS::ExactlyOnce), - q => Err(DeserializeError::UnknownQoS(q)), - } - } -} - -impl MqttWrite for QoS { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - let val = match self { - QoS::AtMostOnce => 0, - QoS::AtLeastOnce => 1, - QoS::ExactlyOnce => 2, - }; - buf.put_u8(val); - Ok(()) - } -} - -impl MqttRead for Box { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let content = Bytes::read(buf)?; - - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s.into()), - Err(e) => Err(DeserializeError::Utf8Error(e)), - } - } -} - -impl MqttWrite for Box { - #[inline(always)] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - self.as_ref().write(buf) - } -} - -impl WireLength for Box { - #[inline(always)] - fn wire_len(&self) -> usize { - self.as_ref().wire_len() - } -} - -impl MqttWrite for &str { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.len() as u16); - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl WireLength for &str { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for String { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let content = Bytes::read(buf)?; - - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s), - Err(e) => Err(DeserializeError::Utf8Error(e)), - } - } -} - -impl MqttWrite for String { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - if self.len() > 65535 { - return Err(SerializeError::StringTooLong(self.len())); - } - - buf.put_u16(self.len() as u16); - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl WireLength for String { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for Bytes { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let len = buf.get_u16() as usize; - - if len > buf.len() { - return Err(DeserializeError::InsufficientData("Bytes".to_string(), buf.len(), len)); - } - - Ok(buf.split_to(len)) - } -} - -impl MqttWrite for Bytes { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.len() as u16); - buf.extend(self); - - Ok(()) - } -} - -impl WireLength for Bytes { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for bool { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("bool".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(error::DeserializeError::MalformedPacket), - } - } -} - -impl MqttWrite for bool { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - if *self { - buf.put_u8(1); - Ok(()) - } else { - buf.put_u8(0); - Ok(()) - } - } -} - -impl MqttRead for u8 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("u8".to_string(), 0, 1)); - } - Ok(buf.get_u8()) - } -} - -impl MqttRead for u16 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.len() < 2 { - return Err(DeserializeError::InsufficientData("u16".to_string(), buf.len(), 2)); - } - Ok(buf.get_u16()) - } -} - -impl MqttWrite for u16 { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(*self); - Ok(()) - } -} - -impl MqttRead for u32 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.len() < 4 { - return Err(DeserializeError::InsufficientData("u32".to_string(), buf.len(), 4)); - } - Ok(buf.get_u32()) - } -} - -impl MqttWrite for u32 { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u32(*self); - Ok(()) - } -} - -pub fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { - let mut integer = 0; - let mut length = 0; - - for i in 0..4 { - if let Some(byte) = buf.next() { - length += 1; - integer += (*byte as usize & 0x7f) << (7 * i); - - if (*byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } - } else { - return Err(ReadBytes::InsufficientBytes(1)); - } - } - Err(ReadBytes::Err(DeserializeError::MalformedPacket)) -} - -pub fn read_variable_integer(buf: &mut Bytes) -> Result<(usize, usize), DeserializeError> { - let mut integer = 0; - let mut length = 0; - - for i in 0..4 { - if buf.is_empty() { - return Err(DeserializeError::MalformedPacket); - } - length += 1; - let byte = buf.get_u8(); - - integer += (byte as usize & 0x7f) << (7 * i); - - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } - } - Err(DeserializeError::MalformedPacket) -} - -pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result<(), SerializeError> { - if integer > 268_435_455 { - return Err(SerializeError::VariableIntegerOverflow(integer)); - } - - let mut write = integer; - - for _ in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf.put_u8(byte); - if write == 0 { - return Ok(()); - } - } - Err(SerializeError::VariableIntegerOverflow(integer)) -} - -pub fn variable_integer_len(integer: usize) -> usize { - if integer >= 2_097_152 { - 4 - } else if integer >= 16_384 { - 3 - } else if integer >= 128 { - 2 - } else { - 1 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PropertyType { - PayloadFormatIndicator = 1, - MessageExpiryInterval = 2, - ContentType = 3, - ResponseTopic = 8, - CorrelationData = 9, - SubscriptionIdentifier = 11, - SessionExpiryInterval = 17, - AssignedClientIdentifier = 18, - ServerKeepAlive = 19, - AuthenticationMethod = 21, - AuthenticationData = 22, - RequestProblemInformation = 23, - WillDelayInterval = 24, - RequestResponseInformation = 25, - ResponseInformation = 26, - ServerReference = 28, - ReasonString = 31, - ReceiveMaximum = 33, - TopicAliasMaximum = 34, - TopicAlias = 35, - MaximumQos = 36, - RetainAvailable = 37, - UserProperty = 38, - MaximumPacketSize = 39, - WildcardSubscriptionAvailable = 40, - SubscriptionIdentifierAvailable = 41, - SharedSubscriptionAvailable = 42, -} - -impl MqttRead for PropertyType { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PropertyType".to_string(), 0, 1)); - } - - match buf.get_u8() { - 1 => Ok(Self::PayloadFormatIndicator), - 2 => Ok(Self::MessageExpiryInterval), - 3 => Ok(Self::ContentType), - 8 => Ok(Self::ResponseTopic), - 9 => Ok(Self::CorrelationData), - 11 => Ok(Self::SubscriptionIdentifier), - 17 => Ok(Self::SessionExpiryInterval), - 18 => Ok(Self::AssignedClientIdentifier), - 19 => Ok(Self::ServerKeepAlive), - 21 => Ok(Self::AuthenticationMethod), - 22 => Ok(Self::AuthenticationData), - 23 => Ok(Self::RequestProblemInformation), - 24 => Ok(Self::WillDelayInterval), - 25 => Ok(Self::RequestResponseInformation), - 26 => Ok(Self::ResponseInformation), - 28 => Ok(Self::ServerReference), - 31 => Ok(Self::ReasonString), - 33 => Ok(Self::ReceiveMaximum), - 34 => Ok(Self::TopicAliasMaximum), - 35 => Ok(Self::TopicAlias), - 36 => Ok(Self::MaximumQos), - 37 => Ok(Self::RetainAvailable), - 38 => Ok(Self::UserProperty), - 39 => Ok(Self::MaximumPacketSize), - 40 => Ok(Self::WildcardSubscriptionAvailable), - 41 => Ok(Self::SubscriptionIdentifierAvailable), - 42 => Ok(Self::SharedSubscriptionAvailable), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PropertyType { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - let val = match self { - Self::PayloadFormatIndicator => 1, - Self::MessageExpiryInterval => 2, - Self::ContentType => 3, - Self::ResponseTopic => 8, - Self::CorrelationData => 9, - Self::SubscriptionIdentifier => 11, - Self::SessionExpiryInterval => 17, - Self::AssignedClientIdentifier => 18, - Self::ServerKeepAlive => 19, - Self::AuthenticationMethod => 21, - Self::AuthenticationData => 22, - Self::RequestProblemInformation => 23, - Self::WillDelayInterval => 24, - Self::RequestResponseInformation => 25, - Self::ResponseInformation => 26, - Self::ServerReference => 28, - Self::ReasonString => 31, - Self::ReceiveMaximum => 33, - Self::TopicAliasMaximum => 34, - Self::TopicAlias => 35, - Self::MaximumQos => 36, - Self::RetainAvailable => 37, - Self::UserProperty => 38, - Self::MaximumPacketSize => 39, - Self::WildcardSubscriptionAvailable => 40, - Self::SubscriptionIdentifierAvailable => 41, - Self::SharedSubscriptionAvailable => 42, - }; - - buf.put_u8(val); - Ok(()) - } -} - -impl PropertyType { - pub fn from_u8(value: u8) -> Result { - match value { - 1 => Ok(Self::PayloadFormatIndicator), - 2 => Ok(Self::MessageExpiryInterval), - 3 => Ok(Self::ContentType), - 8 => Ok(Self::ResponseTopic), - 9 => Ok(Self::CorrelationData), - 11 => Ok(Self::SubscriptionIdentifier), - 17 => Ok(Self::SessionExpiryInterval), - 18 => Ok(Self::AssignedClientIdentifier), - 19 => Ok(Self::ServerKeepAlive), - 21 => Ok(Self::AuthenticationMethod), - 22 => Ok(Self::AuthenticationData), - 23 => Ok(Self::RequestProblemInformation), - 24 => Ok(Self::WillDelayInterval), - 25 => Ok(Self::RequestResponseInformation), - 26 => Ok(Self::ResponseInformation), - 28 => Ok(Self::ServerReference), - 31 => Ok(Self::ReasonString), - 33 => Ok(Self::ReceiveMaximum), - 34 => Ok(Self::TopicAliasMaximum), - 35 => Ok(Self::TopicAlias), - 36 => Ok(Self::MaximumQos), - 37 => Ok(Self::RetainAvailable), - 38 => Ok(Self::UserProperty), - 39 => Ok(Self::MaximumPacketSize), - 40 => Ok(Self::WildcardSubscriptionAvailable), - 41 => Ok(Self::SubscriptionIdentifierAvailable), - 42 => Ok(Self::SharedSubscriptionAvailable), - _ => Err("Unkown property type".to_string()), - } - } - pub fn to_u8(self) -> u8 { - match self { - Self::PayloadFormatIndicator => 1, - Self::MessageExpiryInterval => 2, - Self::ContentType => 3, - Self::ResponseTopic => 8, - Self::CorrelationData => 9, - Self::SubscriptionIdentifier => 11, - Self::SessionExpiryInterval => 17, - Self::AssignedClientIdentifier => 18, - Self::ServerKeepAlive => 19, - Self::AuthenticationMethod => 21, - Self::AuthenticationData => 22, - Self::RequestProblemInformation => 23, - Self::WillDelayInterval => 24, - Self::RequestResponseInformation => 25, - Self::ResponseInformation => 26, - Self::ServerReference => 28, - Self::ReasonString => 31, - Self::ReceiveMaximum => 33, - Self::TopicAliasMaximum => 34, - Self::TopicAlias => 35, - Self::MaximumQos => 36, - Self::RetainAvailable => 37, - Self::UserProperty => 38, - Self::MaximumPacketSize => 39, - Self::WildcardSubscriptionAvailable => 40, - Self::SubscriptionIdentifierAvailable => 41, - Self::SharedSubscriptionAvailable => 42, - } - } -} - -// ==================== Packets ==================== +use self::error::{DeserializeError, SerializeError}; +use self::mqtt_trait::{PacketRead, PacketWrite, WireLength}; +/// Enum to bundle the different MQTT packets. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Packet { Connect(Connect), @@ -583,17 +83,17 @@ impl Packet { } } - pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + pub(crate) fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { match self { Packet::Connect(p) => { buf.put_u8(0b0001_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::ConnAck(p) => { buf.put_u8(0b0010_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Publish(p) => { @@ -608,45 +108,48 @@ impl Packet { first_byte |= 0b0001; } buf.put_u8(first_byte); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubAck(p) => { buf.put_u8(0b0100_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubRec(p) => { buf.put_u8(0b0101_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubRel(p) => { buf.put_u8(0b0110_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubComp(p) => { buf.put_u8(0b0111_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Subscribe(p) => { buf.put_u8(0b1000_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::SubAck(_) => { - unreachable!() + Packet::SubAck(p) => { + buf.put_u8(0b1001_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::Unsubscribe(p) => { buf.put_u8(0b1010_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::UnsubAck(_) => { - unreachable!(); + Packet::UnsubAck(p) => { buf.put_u8(0b1011_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::PingReq => { buf.put_u8(0b1100_0000); @@ -658,19 +161,115 @@ impl Packet { } Packet::Disconnect(p) => { buf.put_u8(0b1110_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Auth(p) => { buf.put_u8(0b1111_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } } Ok(()) } - pub fn read(header: FixedHeader, buf: Bytes) -> Result { + pub(crate) async fn async_write(&self, stream: &mut S) -> Result + where + S: tokio::io::AsyncWrite + Unpin, + { + use tokio::io::AsyncWriteExt; + let mut written = 1; + match self { + Packet::Connect(p) => { + stream.write_u8(0b0001_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::ConnAck(p) => { + stream.write_u8(0b0010_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Publish(p) => { + let mut first_byte = 0b0011_0000u8; + if p.dup { + first_byte |= 0b1000; + } + + first_byte |= p.qos.into_u8() << 1; + + if p.retain { + first_byte |= 0b0001; + } + stream.write_u8(first_byte).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubAck(p) => { + stream.write_u8(0b0100_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubRec(p) => { + stream.write_u8(0b0101_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubRel(p) => { + stream.write_u8(0b0110_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubComp(p) => { + stream.write_u8(0b0111_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Subscribe(p) => { + stream.write_u8(0b1000_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::SubAck(p) => { + stream.write_u8(0b1001_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Unsubscribe(p) => { + stream.write_u8(0b1010_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::UnsubAck(p) => { + stream.write_u8(0b1011_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PingReq => { + stream.write_u8(0b1100_0000).await?; + stream.write_u8(0).await?; // Variable header length. + written += 1; + } + Packet::PingResp => { + stream.write_u8(0b1101_0000).await?; + stream.write_u8(0).await?; // Variable header length. + written += 1; + } + Packet::Disconnect(p) => { + stream.write_u8(0b1110_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Auth(p) => { + stream.write_u8(0b1111_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + } + Ok(written) + } + + pub(crate) fn read_packet(header: FixedHeader, buf: Bytes) -> Result { let packet = match header.packet_type { PacketType::Connect => Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?), PacketType::ConnAck => Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?), @@ -691,7 +290,46 @@ impl Packet { Ok(packet) } - pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + async fn async_read_packet(header: FixedHeader, stream: &mut S) -> Result + where + S: tokio::io::AsyncRead + Unpin, + { + let packet = match header.packet_type { + PacketType::Connect => Packet::Connect(Connect::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::ConnAck => Packet::ConnAck(ConnAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Publish => Packet::Publish(Publish::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubAck => Packet::PubAck(PubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubRec => Packet::PubRec(PubRec::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubRel => Packet::PubRel(PubRel::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubComp => Packet::PubComp(PubComp::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Subscribe => Packet::Subscribe(Subscribe::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::SubAck => Packet::SubAck(SubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Auth => Packet::Auth(Auth::async_read(header.flags, header.remaining_length, stream).await?.0), + }; + Ok(packet) + } + + pub async fn async_read(stream: &mut S) -> Result + where + S: tokio::io::AsyncRead + Unpin, + { + let (header, _) = FixedHeader::async_read(stream).await?; + + #[cfg(feature = "logs")] + tracing::trace!("Read packet header: {:?}", header); + + Packet::async_read_packet(header, stream).await + } + + pub fn read(buffer: &mut BytesMut) -> Result> { + use bytes::Buf; + use error::ReadBytes; + let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; if header.remaining_length + header_length > buffer.len() { return Err(ReadBytes::InsufficientBytes(header.remaining_length + header_length - buffer.len())); @@ -700,7 +338,7 @@ impl Packet { let buf = buffer.split_to(header.remaining_length); - Ok(Packet::read(header, buf.into())?) + Ok(Packet::read_packet(header, buf.into())?) } } @@ -734,39 +372,25 @@ impl Display for Packet { } } -// 2.1.1 Fixed Header -// ``` -// 7 3 0 -// +--------------------------+--------------------------+ -// byte 1 | MQTT Control Packet Type | Flags for Packet type | -// +--------------------------+--------------------------+ -// | Remaining Length | -// +-----------------------------------------------------+ -// -// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 -// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub struct FixedHeader { - pub packet_type: PacketType, - pub flags: u8, - pub remaining_length: usize, -} - -impl FixedHeader { - pub fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { - if header.len() < 2 { - return Err(ReadBytes::InsufficientBytes(2 - header.len())); +impl WireLength for Packet { + fn wire_len(&self) -> usize { + match self { + Packet::Connect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::ConnAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Publish(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRec(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRel(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubComp(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Subscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::SubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Unsubscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::UnsubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PingReq => 2, + Packet::PingResp => 2, + Packet::Disconnect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Auth(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), } - - let mut header_length = 1; - let first_byte = header.next().unwrap(); - - let (packet_type, flags) = PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; - - let (remaining_length, length) = read_fixed_header_rem_len(header)?; - header_length += length; - - Ok((Self { packet_type, flags, remaining_length }, header_length)) } } @@ -790,6 +414,7 @@ pub enum PacketType { Auth, } impl PacketType { + #[inline] const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> { match (value >> 4, value & 0x0f) { (0b0001, 0) => Ok((PacketType::Connect, 0)), @@ -820,207 +445,168 @@ impl std::fmt::Display for PacketType { #[cfg(test)] mod tests { - use bytes::{Bytes, BytesMut}; - use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties}; - use crate::packets::disconnect::{Disconnect, DisconnectProperties}; - use crate::packets::QoS; + use bytes::BytesMut; - use crate::packets::publish::{Publish, PublishProperties}; - use crate::packets::pubrel::{PubRel, PubRelProperties}; - use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode}; use crate::packets::Packet; - #[test] - fn test_connack_read() { - let connack = [ - 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, - ]; - let mut buf = BytesMut::new(); - buf.extend(connack); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = ConnAck { - connack_flags: ConnAckFlags { session_present: true }, - reason_code: ConnAckReasonCode::Success, - connack_properties: ConnAckProperties { - session_expiry_interval: None, - receive_maximum: None, - maximum_qos: None, - retain_available: Some(true), - maximum_packet_size: Some(1048576), - assigned_client_id: None, - topic_alias_maximum: Some(65535), - reason_string: None, - user_properties: vec![], - wildcards_available: Some(true), - subscription_ids_available: Some(true), - shared_subscription_available: Some(true), - server_keep_alive: None, - response_info: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }, - }; - - assert_eq!(Packet::ConnAck(expected), res); - } - - #[test] - fn test_disconnect_read() { - let packet = [0xe0, 0x02, 0x8e, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Disconnect { - reason_code: DisconnectReasonCode::SessionTakenOver, - properties: DisconnectProperties { - session_expiry_interval: None, - reason_string: None, - user_properties: vec![], - server_reference: None, - }, - }; + use crate::tests::test_packets::*; + + #[rstest::rstest] + #[case::connect_case(connect_case())] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] + #[case::pubcomp_case(pubcomp_case())] + #[case::pubrec_case(pubrec_case())] + #[case::pubrec_case(pubrel_case2())] + #[case::auth_case(auth_case())] + fn test_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + + let mut buffer = BytesMut::new(); - assert_eq!(Packet::Disconnect(expected), res); - } - - #[test] - fn test_pingreq_read_write() { - let packet = [0xc0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingReq, res); - - buf.clear(); - Packet::PingReq.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_pingresp_read_write() { - let packet = [0xd0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingResp, res); - - buf.clear(); - Packet::PingResp.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_publish_read() { - let packet = [ - 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, 0x01, 0x09, 0x00, - 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, - ]; - - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Publish { - dup: false, - qos: QoS::ExactlyOnce, - retain: true, - topic: "test/123/test/blabla".into(), - packet_identifier: Some(13779), - publish_properties: PublishProperties { - payload_format_indicator: Some(1), - message_expiry_interval: None, - topic_alias: None, - response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], - user_properties: vec![], - content_type: None, - }, - payload: Bytes::from_static(b""), - }; - - assert_eq!(Packet::Publish(expected), res); - } - - #[test] - fn test_pubrel_read_write() { - let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00]; + packet.write(&mut buffer).unwrap(); + let wire_len = packet.wire_len(); + assert_eq!(wire_len, buffer.len()); + + // dbg!(wire_len); + + // let a: Vec<_> = buffer.iter().map(|f| *f as u16).collect(); + // println!("{:?}", a); + + let res1 = Packet::read(&mut buffer).unwrap(); + + assert_eq!(packet, res1); + + let mut buffer = BytesMut::new(); + res1.write(&mut buffer).unwrap(); + let res2 = Packet::read(&mut buffer).unwrap(); + + assert_eq!(res1, res2); + } + + #[rstest::rstest] + #[case::connect_case(connect_case())] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] + #[case::pubcomp_case(pubcomp_case())] + #[case::pubrec_case(pubrec_case())] + #[case::pubrec_case(pubrel_case2())] + #[case::auth_case(auth_case())] + #[tokio::test] + async fn test_async_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + + let mut buffer = Vec::with_capacity(1000); + let res = packet.async_write(&mut buffer).await.unwrap(); + + let wire_len = packet.wire_len(); + + assert_eq!(res, buffer.len()); + assert_eq!(wire_len, buffer.len()); + + let mut buf = buffer.as_slice(); + + let res1 = Packet::async_read(&mut buf).await.unwrap(); + + pretty_assertions::assert_eq!(packet, res1); + } + + #[rstest::rstest] + #[case::disconnect(disconnect_case())] + #[case::ping_req(ping_req_case())] + #[case::ping_resp(ping_resp_case())] + #[case::publish(publish_case())] + #[case::pubrel(pubrel_case())] + #[case::pubrel_smallest(pubrel_smallest_case())] + fn test_read_write_cases(#[case] (bytes, expected_packet): (&[u8], Packet)) { let mut buffer = BytesMut::from_iter(bytes); - let res = Packet::read_from_buffer(&mut buffer); + let res = Packet::read(&mut buffer); assert!(res.is_ok()); let packet = res.unwrap(); - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); + assert_eq!(packet, expected_packet); buffer.clear(); packet.write(&mut buffer).unwrap(); - // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format. - assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec()) + assert_eq!(buffer.to_vec(), bytes.to_vec()) } - #[test] - fn test_pubrel_read_smallest_format() { - let bytes = [0x62, 0x02, 0x35, 0xd3]; - - let mut buffer = BytesMut::from_iter(bytes); + #[rstest::rstest] + #[case::disconnect(disconnect_case())] + #[case::ping_req(ping_req_case())] + #[case::ping_resp(ping_resp_case())] + #[case::publish(publish_case())] + #[case::pubrel(pubrel_case())] + #[case::pubrel_smallest(pubrel_smallest_case())] + #[tokio::test] + async fn test_async_read_write(#[case] (mut bytes, expected_packet): (&[u8], Packet)) { + let input = bytes.to_vec(); - let res = Packet::read_from_buffer(&mut buffer); + let res = Packet::async_read(&mut bytes).await; + dbg!(&res); assert!(res.is_ok()); let packet = res.unwrap(); - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); + assert_eq!(packet, expected_packet); - buffer.clear(); + let mut out = Vec::with_capacity(1000); - packet.write(&mut buffer).unwrap(); + packet.async_write(&mut out).await.unwrap(); - assert_eq!(buffer.to_vec(), bytes.to_vec()) + assert_eq!(out, input) } + + // #[rstest::rstest] + // #[case(&[59, 1, 0, 59])] + // #[case(&[16, 14, 0, 4, 77, 81, 84, 84, 5, 247, 247, 252, 1, 17, 247, 247, 247])] + // fn test_read_error(#[case] bytes: &[u8]) { + // let mut buffer = BytesMut::from_iter(bytes); + + // let res = Packet::read_from_buffer(&mut buffer); + + // assert!(res.is_err()); + // } } diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs new file mode 100644 index 0000000..f666753 --- /dev/null +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -0,0 +1,62 @@ +mod primitive_impl; + +use std::future::Future; + +use bytes::{Bytes, BytesMut}; + +pub(crate) trait PacketRead: Sized { + fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; +} + +pub(crate) trait PacketAsyncRead: Sized +where + S: tokio::io::AsyncRead + Unpin, +{ + fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; +} + +pub(crate) trait PacketAsyncWrite: Sized +where + S: tokio::io::AsyncWriteExt + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl Future>; +} + +pub(crate) trait PacketWrite: Sized { + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +} + +pub(crate) trait WireLength { + fn wire_len(&self) -> usize; +} + +pub(crate) trait MqttRead: Sized { + fn read(buf: &mut Bytes) -> Result; +} +pub(crate) trait MqttAsyncRead: Sized { + /// Reads `Self` from the provided stream. + /// Returns the deserialized instance and the number of bytes read from the stream. + fn async_read(stream: &mut S) -> impl Future>; +} + +pub trait MqttWrite: Sized { + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +} + +impl MqttWrite for &T +where + T: MqttWrite, +{ + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError> { + ::write(self, buf) + } +} +pub(crate) trait MqttAsyncWrite: Sized { + /// Write `Self` to the provided stream. + /// Returns the deserialized instance and the number of bytes read from the stream. + fn async_write(&self, stream: &mut S) -> impl Future>; +} + +pub trait PacketValidation: Sized { + fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; +} diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs new file mode 100644 index 0000000..b0612cb --- /dev/null +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -0,0 +1,398 @@ +use tokio::io::AsyncWriteExt; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::{DeserializeError, ReadError, SerializeError}; +use crate::packets::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, WireLength}; + +use super::MqttAsyncWrite; + +impl MqttRead for Box { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let content = Bytes::read(buf)?; + + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s.into()), + Err(e) => Err(DeserializeError::Utf8Error(e)), + } + } +} + +impl MqttAsyncRead for Box +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Vec::async_read(stream).await?; + match String::from_utf8(content) { + Ok(s) => Ok((s.into(), read_bytes)), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + +impl MqttWrite for Box { + #[inline(always)] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.as_ref().write(buf) + } +} + +impl MqttAsyncWrite for Box +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } +} + +impl WireLength for Box { + #[inline(always)] + fn wire_len(&self) -> usize { + self.as_ref().wire_len() + } +} + +impl MqttWrite for &str { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self.as_bytes()); + Ok(()) + } +} + +impl MqttAsyncWrite for &str +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } +} + +impl WireLength for &str { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for String { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let content = Bytes::read(buf)?; + + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s), + Err(e) => Err(DeserializeError::Utf8Error(e)), + } + } +} + +impl MqttAsyncRead for String +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Bytes::async_read(buf).await?; + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok((s, read_bytes)), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + +impl MqttWrite for String { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + if self.len() > 65535 { + return Err(SerializeError::StringTooLong(self.len())); + } + + buf.put_u16(self.len() as u16); + buf.extend(self.as_bytes()); + Ok(()) + } +} +impl MqttAsyncWrite for String +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } +} + +impl WireLength for String { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for Bytes { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } + let len = buf.get_u16() as usize; + + if len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + Ok(buf.split_to(len)) + } +} +impl MqttAsyncRead for Bytes +where + S: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = Vec::with_capacity(size); + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data.into(), 2 + size)) + } +} +impl MqttWrite for Bytes { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self); + + Ok(()) + } +} +impl MqttAsyncWrite for Bytes +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_ref()).await?; + Ok(2 + self.len()) + } +} + +impl WireLength for Bytes { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for Vec { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } + let len = buf.get_u16() as usize; + + if len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + Ok(buf.split_to(len).into()) + } +} +impl MqttWrite for Vec { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self); + + Ok(()) + } +} +impl MqttAsyncWrite for Vec +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self).await?; + Ok(2 + self.len()) + } +} +impl WireLength for Vec { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} +impl MqttAsyncRead for Vec +where + S: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = vec![0u8; size]; + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data, 2 + size)) + } +} + +impl MqttRead for bool { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + match buf.get_u8() { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(crate::packets::error::DeserializeError::MalformedPacket), + } + } +} +impl MqttAsyncRead for bool +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + match buf.read_u8().await? { + 0 => Ok((false, 1)), + 1 => Ok((true, 1)), + _ => Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)), + } + } +} +impl MqttWrite for bool { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + if *self { + buf.put_u8(1); + Ok(()) + } else { + buf.put_u8(0); + Ok(()) + } + } +} +impl MqttAsyncWrite for bool +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + if *self { + stream.write_all(&[1]).await?; + } else { + stream.write_all(&[0]).await?; + } + Ok(1) + } +} +impl MqttRead for u8 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + Ok(buf.get_u8()) + } +} +impl MqttAsyncRead for u8 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u8().await?, 1)) + } +} +impl MqttAsyncWrite for u8 +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(1) + } +} + +impl MqttRead for u16 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } + Ok(buf.get_u16()) + } +} +impl MqttAsyncRead for u16 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u16().await?, 2)) + } +} +impl MqttWrite for u16 { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(*self); + Ok(()) + } +} +impl MqttAsyncWrite for u16 +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(2) + } +} + +impl MqttRead for u32 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 4 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); + } + Ok(buf.get_u32()) + } +} +impl MqttAsyncRead for u32 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u32().await?, 4)) + } +} +impl MqttWrite for u32 { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u32(*self); + Ok(()) + } +} +impl MqttAsyncWrite for u32 +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(4) + } +} diff --git a/mqrstt/src/packets/mqtt_traits.rs b/mqrstt/src/packets/mqtt_traits.rs deleted file mode 100644 index 8f6122f..0000000 --- a/mqrstt/src/packets/mqtt_traits.rs +++ /dev/null @@ -1,36 +0,0 @@ -use bytes::{Bytes, BytesMut}; - -use super::error::{DeserializeError, SerializeError}; - -pub trait VariableHeaderRead: Sized { - fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; -} - -pub trait VariableHeaderWrite: Sized { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; -} - -pub trait WireLength { - fn wire_len(&self) -> usize; -} - -pub trait MqttRead: Sized { - fn read(buf: &mut Bytes) -> Result; -} - -pub trait MqttWrite: Sized { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; -} - -impl MqttWrite for &T -where - T: MqttWrite, -{ - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - ::write(self, buf) - } -} - -pub trait PacketValidation: Sized { - fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; -} diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs new file mode 100644 index 0000000..4f7c626 --- /dev/null +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -0,0 +1,60 @@ +use core::slice::Iter; + +use tokio::io::AsyncReadExt; + +use crate::packets::{ + error::{DeserializeError, ReadBytes}, + PacketType, +}; + +use super::read_fixed_header_rem_len; + +/// 2.1.1 Fixed Header +/// +/// The fixed header indicates the pakcet type in the first four bits [7 - 4] and for some packets it also contains some flags in the second four bits [3 - 0]. +/// The remaining length encodes the length of the variable header and the payload. +/// +/// | Bit | 7 - 4 | 3 - 0 | +/// |----------|----------------------------|----------------------------| +/// | byte 1 | MQTT Control Packet Type | Flags for Packet type | +/// | | | | +/// | byte 2+ | Remaining Length | +/// | |---------------------------------------------------------| +/// +/// [MQTT v5.0 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021) +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub(crate) struct FixedHeader { + pub packet_type: PacketType, + pub flags: u8, + pub remaining_length: usize, +} + +impl FixedHeader { + pub(crate) fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { + if header.len() < 2 { + return Err(ReadBytes::InsufficientBytes(2 - header.len())); + } + + let mut header_length = 1; + let first_byte = header.next().unwrap(); + + let (packet_type, flags) = PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; + + let (remaining_length, length) = read_fixed_header_rem_len(header)?; + header_length += length; + + Ok((Self { packet_type, flags, remaining_length }, header_length)) + } + + pub(crate) async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> + where + S: tokio::io::AsyncRead + Unpin, + { + let first_byte = stream.read_u8().await?; + + let (packet_type, flags) = PacketType::from_first_byte(first_byte)?; + + let (remaining_length, length) = super::async_read_fixed_header_rem_len(stream).await?; + Ok((Self { packet_type, flags, remaining_length }, 1 + length)) + } +} diff --git a/mqrstt/src/packets/primitive/mod.rs b/mqrstt/src/packets/primitive/mod.rs new file mode 100644 index 0000000..4f9d756 --- /dev/null +++ b/mqrstt/src/packets/primitive/mod.rs @@ -0,0 +1,14 @@ +mod fixed_header; +pub(crate) use fixed_header::FixedHeader; + +mod protocol_version; +pub use protocol_version::ProtocolVersion; + +mod property_type; +pub(crate) use property_type::PropertyType; + +mod variable_integer; +pub(crate) use variable_integer::*; + +mod qos; +pub use qos::QoS; diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs new file mode 100644 index 0000000..a1ecf70 --- /dev/null +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -0,0 +1,156 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +impl TryFrom for PropertyType { + type Error = DeserializeError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::PayloadFormatIndicator), + 2 => Ok(Self::MessageExpiryInterval), + 3 => Ok(Self::ContentType), + 8 => Ok(Self::ResponseTopic), + 9 => Ok(Self::CorrelationData), + 11 => Ok(Self::SubscriptionIdentifier), + 17 => Ok(Self::SessionExpiryInterval), + 18 => Ok(Self::AssignedClientIdentifier), + 19 => Ok(Self::ServerKeepAlive), + 21 => Ok(Self::AuthenticationMethod), + 22 => Ok(Self::AuthenticationData), + 23 => Ok(Self::RequestProblemInformation), + 24 => Ok(Self::WillDelayInterval), + 25 => Ok(Self::RequestResponseInformation), + 26 => Ok(Self::ResponseInformation), + 28 => Ok(Self::ServerReference), + 31 => Ok(Self::ReasonString), + 33 => Ok(Self::ReceiveMaximum), + 34 => Ok(Self::TopicAliasMaximum), + 35 => Ok(Self::TopicAlias), + 36 => Ok(Self::MaximumQos), + 37 => Ok(Self::RetainAvailable), + 38 => Ok(Self::UserProperty), + 39 => Ok(Self::MaximumPacketSize), + 40 => Ok(Self::WildcardSubscriptionAvailable), + 41 => Ok(Self::SubscriptionIdentifierAvailable), + 42 => Ok(Self::SharedSubscriptionAvailable), + t => Err(DeserializeError::UnknownProperty(t)), + } + } +} + +impl From<&PropertyType> for u8 { + fn from(value: &PropertyType) -> Self { + match value { + PropertyType::PayloadFormatIndicator => 1, + PropertyType::MessageExpiryInterval => 2, + PropertyType::ContentType => 3, + PropertyType::ResponseTopic => 8, + PropertyType::CorrelationData => 9, + PropertyType::SubscriptionIdentifier => 11, + PropertyType::SessionExpiryInterval => 17, + PropertyType::AssignedClientIdentifier => 18, + PropertyType::ServerKeepAlive => 19, + PropertyType::AuthenticationMethod => 21, + PropertyType::AuthenticationData => 22, + PropertyType::RequestProblemInformation => 23, + PropertyType::WillDelayInterval => 24, + PropertyType::RequestResponseInformation => 25, + PropertyType::ResponseInformation => 26, + PropertyType::ServerReference => 28, + PropertyType::ReasonString => 31, + PropertyType::ReceiveMaximum => 33, + PropertyType::TopicAliasMaximum => 34, + PropertyType::TopicAlias => 35, + PropertyType::MaximumQos => 36, + PropertyType::RetainAvailable => 37, + PropertyType::UserProperty => 38, + PropertyType::MaximumPacketSize => 39, + PropertyType::WildcardSubscriptionAvailable => 40, + PropertyType::SubscriptionIdentifierAvailable => 41, + PropertyType::SharedSubscriptionAvailable => 42, + } + } +} + +impl From for u8 { + fn from(value: PropertyType) -> Self { + value as u8 + } +} + +impl MqttRead for PropertyType { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + buf.get_u8().try_into() + } +} + +impl MqttAsyncRead for PropertyType +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + match stream.read_u8().await { + Ok(t) => Ok((t.try_into()?, 1)), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + +impl MqttWrite for PropertyType { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u8(self.into()); + Ok(()) + } +} + +impl MqttAsyncWrite for PropertyType +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [u8::from(self)]; + stream.write_all(&buf).await?; + Ok(1) + } +} diff --git a/mqrstt/src/packets/primitive/protocol_version.rs b/mqrstt/src/packets/primitive/protocol_version.rs new file mode 100644 index 0000000..bed9a18 --- /dev/null +++ b/mqrstt/src/packets/primitive/protocol_version.rs @@ -0,0 +1,66 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use tokio::io::AsyncReadExt; + +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, +}; + +/// Protocol version of the MQTT connection +/// +/// This client only supports MQTT v5.0. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum ProtocolVersion { + V5, +} + +impl MqttWrite for ProtocolVersion { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u8(5u8); + Ok(()) + } +} + +impl MqttAsyncWrite for ProtocolVersion +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use tokio::io::AsyncWriteExt; + async move { + stream.write_u8(5).await?; + Ok(1) + } + } +} + +impl MqttRead for ProtocolVersion { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientDataForProtocolVersion); + } + + match buf.get_u8() { + 3 => Err(DeserializeError::UnsupportedProtocolVersion), + 4 => Err(DeserializeError::UnsupportedProtocolVersion), + 5 => Ok(ProtocolVersion::V5), + _ => Err(DeserializeError::UnknownProtocolVersion), + } + } +} + +impl MqttAsyncRead for ProtocolVersion +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + match stream.read_u8().await { + Ok(5) => Ok((ProtocolVersion::V5, 1)), + Ok(4) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(3) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(_) => Err(ReadError::DeserializeError(DeserializeError::UnknownProtocolVersion)), + Err(e) => Err(ReadError::IoError(e)), + } + } +} diff --git a/mqrstt/src/packets/primitive/qos.rs b/mqrstt/src/packets/primitive/qos.rs new file mode 100644 index 0000000..06be58e --- /dev/null +++ b/mqrstt/src/packets/primitive/qos.rs @@ -0,0 +1,84 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, +}; + +use tokio::io::AsyncWriteExt; + +/// Quality of service +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum QoS { + #[default] + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} +impl QoS { + pub fn from_u8(value: u8) -> Result { + match value { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + _ => Err(DeserializeError::UnknownQoS(value)), + } + } + pub fn into_u8(self) -> u8 { + match self { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce => 1, + QoS::ExactlyOnce => 2, + } + } +} + +impl MqttRead for QoS { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + match buf.get_u8() { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + q => Err(DeserializeError::UnknownQoS(q)), + } + } +} + +impl MqttAsyncRead for QoS +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + match buf.read_u8().await { + Ok(0) => Ok((QoS::AtMostOnce, 1)), + Ok(1) => Ok((QoS::AtLeastOnce, 1)), + Ok(2) => Ok((QoS::ExactlyOnce, 1)), + Ok(q) => Err(ReadError::DeserializeError(DeserializeError::UnknownQoS(q))), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + +impl MqttWrite for QoS { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + let val = self.into_u8(); + buf.put_u8(val); + Ok(()) + } +} +impl MqttAsyncWrite for QoS +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [self.into_u8()]; + stream.write_all(&buf).await?; + Ok(1) + } +} diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs new file mode 100644 index 0000000..1841ead --- /dev/null +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -0,0 +1,248 @@ +use crate::packets::error::WriteError; +use crate::packets::error::{DeserializeError, ReadBytes, ReadError, SerializeError}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use core::slice::Iter; +use std::future::Future; + +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; + +pub(crate) fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if let Some(byte) = buf.next() { + length += 1; + integer += (*byte as usize & 0x7f) << (7 * i); + + if (*byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } else { + return Err(ReadBytes::InsufficientBytes(1)); + } + } + Err(ReadBytes::Err(DeserializeError::MalformedPacket)) +} + +pub(crate) async fn async_read_fixed_header_rem_len(stream: &mut S) -> Result<(usize, usize), ReadError> +where + S: tokio::io::AsyncRead + Unpin, +{ + let mut integer = 0; + let mut length = 0; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) +} +pub(crate) trait VariableInteger: Sized { + fn variable_integer_len(&self) -> usize; + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result; + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError>; + fn read_async_variable_integer(stream: &mut S) -> impl Future>; + fn write_async_variable_integer(&self, stream: &mut S) -> impl Future>; +} + +impl VariableInteger for usize { + fn variable_integer_len(&self) -> usize { + if *self >= 2_097_152 { + 4 + } else if *self >= 16_384 { + 3 + } else if *self >= 128 { + 2 + } else { + 1 + } + } + + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { + if *self > 268_435_455 { + return Err(SerializeError::VariableIntegerOverflow(*self)); + } + + let mut write = *self; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf.put_u8(byte); + if write == 0 { + return Ok(i + 1); + } + } + Err(SerializeError::VariableIntegerOverflow(*self)) + } + + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if buf.is_empty() { + return Err(DeserializeError::MalformedPacket); + } + length += 1; + let byte = buf.get_u8(); + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(DeserializeError::MalformedPacket) + } + + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) + } + + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; + + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self))); + } + + let mut write = *self; + let mut length = 1; + + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; + } + } + stream.write_all(&buf[0..length]).await?; + Ok(length) + } +} + +impl VariableInteger for u32 { + fn variable_integer_len(&self) -> usize { + if *self >= 2_097_152 { + 4 + } else if *self >= 16_384 { + 3 + } else if *self >= 128 { + 2 + } else { + 1 + } + } + + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { + if *self > 268_435_455 { + return Err(SerializeError::VariableIntegerOverflow(*self as usize)); + } + + let mut write = *self; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf.put_u8(byte); + if write == 0 { + return Ok(i + 1); + } + } + Err(SerializeError::VariableIntegerOverflow(*self as usize)) + } + + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if buf.is_empty() { + return Err(DeserializeError::MalformedPacket); + } + length += 1; + let byte = buf.get_u8(); + + integer += (byte as u32 & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(DeserializeError::MalformedPacket) + } + + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + + integer += (byte as u32 & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) + } + + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; + + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); + } + + let mut write = *self; + let mut length = 1; + + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; + } + } + stream.write_all(&buf[0..length]).await?; + Ok(length) + } +} diff --git a/mqrstt/src/packets/puback.rs b/mqrstt/src/packets/puback/mod.rs similarity index 68% rename from mqrstt/src/packets/puback.rs rename to mqrstt/src/packets/puback/mod.rs index 4b0970c..55f5886 100644 --- a/mqrstt/src/packets/puback.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -1,13 +1,16 @@ -use bytes::BufMut; +mod reason_code; +pub use reason_code::PubAckReasonCode; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::PubAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, VariableInteger, }; +use bytes::BufMut; +use tokio::io::AsyncReadExt; +/// The PUBACK Packet is the response to a PUBLISH Packet with QoS 1. +/// Both the server and client can send a PUBACK packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubAck { pub packet_identifier: u16, @@ -15,7 +18,44 @@ pub struct PubAck { pub properties: PubAckProperties, } -impl VariableHeaderRead for PubAck { +impl PacketAsyncRead for PubAck +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let packet_identifier = stream.read_u16().await?; + if remaining_length == 2 { + Ok(( + Self { + packet_identifier, + reason_code: PubAckReasonCode::Success, + properties: PubAckProperties::default(), + }, + 2, + )) + } else if remaining_length < 4 { + return Err(crate::packets::error::ReadError::DeserializeError(DeserializeError::InsufficientData( + std::any::type_name::(), + remaining_length, + 4, + ))); + } else { + let (reason_code, reason_code_read_bytes) = PubAckReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubAckProperties::async_read(stream).await?; + + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + 2 + reason_code_read_bytes + properties_read_bytes, + )) + } + } +} + +impl PacketRead for PubAck { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -27,7 +67,7 @@ impl VariableHeaderRead for PubAck { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubAck".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -42,7 +82,7 @@ impl VariableHeaderRead for PubAck { } } -impl VariableHeaderWrite for PubAck { +impl PacketWrite for PubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -58,6 +98,29 @@ impl VariableHeaderWrite for PubAck { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 2; + self.packet_identifier.async_write(stream).await?; + + if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_written_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_written_bytes += self.reason_code.async_write(stream).await?; + } else { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + impl WireLength for PubAck { fn wire_len(&self) -> usize { if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { @@ -69,38 +132,33 @@ impl WireLength for PubAck { } else { let prop_len = self.properties.wire_len(); // pkid, reason code, length of the length of properties and lenght of properties - 3 + variable_integer_len(prop_len) + prop_len + 3 + prop_len.variable_integer_len() + prop_len } } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubAckProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubAckProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} +crate::packets::macros::define_properties!( + /// PubAck Properties + PubAckProperties, + ReasonString, + UserProperty +); impl MqttRead for PubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; + let (len, _) = VariableInteger::read_variable_integer(buf)?; if len == 0 { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubAckProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties = PubAckProperties::default(); loop { - match PropertyType::from_u8(u8::read(buf)?)? { + match PropertyType::try_from(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); @@ -122,7 +180,7 @@ impl MqttWrite for PubAckProperties { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let len = self.wire_len(); - write_variable_integer(buf, len)?; + len.write_variable_integer(buf)?; if let Some(reason_string) = &self.reason_string { PropertyType::ReasonString.write(buf)?; @@ -138,27 +196,12 @@ impl MqttWrite for PubAckProperties { } } -impl WireLength for PubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - - len - } -} - #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, puback::{PubAck, PubAckProperties}, - reason_codes::PubAckReasonCode, - write_variable_integer, PropertyType, + PropertyType, PubAckReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -221,7 +264,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -248,7 +291,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubAckProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/puback/reason_code.rs b/mqrstt/src/packets/puback/reason_code.rs new file mode 100644 index 0000000..ea78a45 --- /dev/null +++ b/mqrstt/src/packets/puback/reason_code.rs @@ -0,0 +1,12 @@ +crate::packets::macros::reason_code!( + PubAckReasonCode, + Success, + NoMatchingSubscribers, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicNameInvalid, + PacketIdentifierInUse, + QuotaExceeded, + PayloadFormatInvalid +); diff --git a/mqrstt/src/packets/pubcomp.rs b/mqrstt/src/packets/pubcomp/mod.rs similarity index 62% rename from mqrstt/src/packets/pubcomp.rs rename to mqrstt/src/packets/pubcomp/mod.rs index b814da2..dce9b43 100644 --- a/mqrstt/src/packets/pubcomp.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -1,13 +1,22 @@ -use bytes::BufMut; +mod reason_code; +pub use reason_code::PubCompReasonCode; + +mod properties; +pub use properties::PubCompProperties; use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::PubCompReasonCode, - write_variable_integer, PacketType, PropertyType, + error::{DeserializeError, ReadError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, }; +use bytes::BufMut; +use tokio::io::AsyncReadExt; +/// The PUBCOMP Packet is the response to a PUBLISH Packet with QoS 2. +/// It is the fourth and final packet of the QoS 2 protocol exchange. +/// The user of the client application does not have to send this packet, it is handled internally by the client. +/// +/// Both the client and server can send this packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubComp { pub packet_identifier: u16, @@ -25,7 +34,7 @@ impl PubComp { } } -impl VariableHeaderRead for PubComp { +impl PacketRead for PubComp { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -37,7 +46,7 @@ impl VariableHeaderRead for PubComp { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubComp".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -52,12 +61,56 @@ impl VariableHeaderRead for PubComp { } } -impl VariableHeaderWrite for PubComp { +impl PacketAsyncRead for PubComp +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let packet_identifier = stream.read_u16().await?; + if remaining_length == 2 { + return Ok(( + Self { + packet_identifier, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), + }, + 2, + )); + } + // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes + else if remaining_length < 4 { + return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); + } + + let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; + + let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; + + if total_read_bytes != remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { + read: total_read_bytes, + remaining_length, + })); + } + + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for PubComp { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - // nothing here + return Ok(()); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { @@ -68,103 +121,46 @@ impl VariableHeaderWrite for PubComp { } } -impl WireLength for PubComp { - fn wire_len(&self) -> usize { +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubComp +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + let mut total_written_bytes = 2; + self.packet_identifier.async_write(stream).await?; + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - 2 + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - 3 + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - 2 + 1 + self.properties.wire_len() + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } + Ok(total_written_bytes) } } -#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] -pub struct PubCompProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubCompProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubCompProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubCompProperties".to_string(), buf.len(), len)); - } - - let mut properties = PubCompProperties::default(); - - loop { - match PropertyType::from_u8(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubComp)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for PubCompProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -impl WireLength for PubCompProperties { +impl WireLength for PubComp { fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + 2 + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + 3 + } else { + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len } - - len } } #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubcomp::{PubComp, PubCompProperties}, - reason_codes::PubCompReasonCode, - write_variable_integer, PropertyType, + PropertyType, PubCompReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -227,7 +223,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -252,7 +248,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubCompProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs new file mode 100644 index 0000000..ff5fb49 --- /dev/null +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -0,0 +1,64 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + /// PubComp Properties + PubCompProperties, + ReasonString, + UserProperty +); + +impl MqttRead for PubCompProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubCompProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubComp)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for PubCompProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/pubcomp/reason_code.rs b/mqrstt/src/packets/pubcomp/reason_code.rs new file mode 100644 index 0000000..a5c531a --- /dev/null +++ b/mqrstt/src/packets/pubcomp/reason_code.rs @@ -0,0 +1,5 @@ +crate::packets::macros::reason_code!( + PubCompReasonCode, + Success, + PacketIdentifierNotFound +); diff --git a/mqrstt/src/packets/publish.rs b/mqrstt/src/packets/publish.rs deleted file mode 100644 index 6f3baeb..0000000 --- a/mqrstt/src/packets/publish.rs +++ /dev/null @@ -1,375 +0,0 @@ -use bytes::{BufMut, Bytes}; - -use crate::error::PacketValidationError; -use crate::util::constants::MAXIMUM_TOPIC_SIZE; - -use super::mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}; -use super::{ - error::{DeserializeError, SerializeError}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, -}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct Publish { - /// 3.3.1.1 dup - pub dup: bool, - /// 3.3.1.2 QoS - pub qos: QoS, - /// 3.3.1.3 retain - pub retain: bool, - - /// 3.3.2.1 Topic Name - /// The Topic Name identifies the information channel to which Payload data is published. - pub topic: Box, - - /// 3.3.2.2 Packet Identifier - /// The Packet Identifier field is only present in PUBLISH packets where the QoS level is 1 or 2. Section 2.2.1 provides more information about Packet Identifiers. - pub packet_identifier: Option, - - /// 3.3.2.3 PUBLISH Properties - pub publish_properties: PublishProperties, - - /// 3.3.3 PUBLISH Payload - pub payload: Bytes, -} - -impl Publish { - pub fn new>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: Bytes) -> Self { - Self { - dup: false, - qos, - retain, - topic: topic.as_ref().into(), - packet_identifier, - publish_properties, - payload, - } - } - - pub fn payload_to_vec(&self) -> Vec { - self.payload.to_vec() - } -} - -impl VariableHeaderRead for Publish { - fn read(flags: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let dup = flags & 0b1000 != 0; - let qos = QoS::from_u8((flags & 0b110) >> 1)?; - let retain = flags & 0b1 != 0; - - let topic = Box::::read(&mut buf)?; - let mut packet_identifier = None; - if qos != QoS::AtMostOnce { - packet_identifier = Some(u16::read(&mut buf)?); - } - - let publish_properties = PublishProperties::read(&mut buf)?; - - Ok(Self { - dup, - qos, - retain, - topic, - packet_identifier, - publish_properties, - payload: buf, - }) - } -} - -impl VariableHeaderWrite for Publish { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - self.topic.write(buf)?; - - if let Some(pkid) = self.packet_identifier { - buf.put_u16(pkid); - } - - self.publish_properties.write(buf)?; - - buf.extend(&self.payload); - - Ok(()) - } -} - -impl WireLength for Publish { - fn wire_len(&self) -> usize { - let mut len = self.topic.wire_len(); - if self.packet_identifier.is_some() { - len += 2; - } - - let properties_len = self.publish_properties.wire_len(); - - len += variable_integer_len(properties_len); - len += properties_len; - len += self.payload.len(); - len - } -} - -impl PacketValidation for Publish { - fn validate(&self, max_packet_size: usize) -> Result<(), PacketValidationError> { - use PacketValidationError::*; - if self.wire_len() > max_packet_size { - Err(MaxPacketSize(self.wire_len())) - } else if self.topic.len() > MAXIMUM_TOPIC_SIZE { - Err(TopicSize(self.topic.len())) - } else { - Ok(()) - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct PublishProperties { - /// 3.3.2.3.2 Payload Format Indicator - /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. - pub payload_format_indicator: Option, - - /// 3.3.2.3.3 Message Expiry Interval - /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. - pub message_expiry_interval: Option, - - /// 3.3.2.3.4 Topic Alias - /// 35 (0x23) Byte, Identifier of the Topic Alias. - pub topic_alias: Option, - - /// 3.3.2.3.5 Response Topic - /// 8 (0x08) Byte, Identifier of the Response Topic. - pub response_topic: Option>, - - /// 3.3.2.3.6 Correlation Data - /// 9 (0x09) Byte, Identifier of the Correlation Data. - pub correlation_data: Option, - - /// 3.3.2.3.8 Subscription Identifier - /// 11 (0x0B), Identifier of the Subscription Identifier. - pub subscription_identifier: Vec, - - /// 3.3.2.3.7 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, - - /// 3.3.2.3.9 Content Type - /// 3 (0x03) Identifier of the Content Type - pub content_type: Option>, -} - -impl MqttRead for PublishProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - if len == 0 { - return Ok(Self::default()); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("PublishProperties".to_string(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - let mut properties = Self::default(); - - loop { - match PropertyType::from_u8(u8::read(&mut property_data)?)? { - PropertyType::PayloadFormatIndicator => { - if properties.payload_format_indicator.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); - } - properties.payload_format_indicator = Some(u8::read(&mut property_data)?); - } - PropertyType::MessageExpiryInterval => { - if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); - } - properties.message_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::TopicAlias => { - if properties.topic_alias.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAlias)); - } - properties.topic_alias = Some(u16::read(&mut property_data)?); - } - PropertyType::ResponseTopic => { - if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); - } - properties.response_topic = Some(Box::::read(&mut property_data)?); - } - PropertyType::CorrelationData => { - if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); - } - properties.correlation_data = Some(Bytes::read(&mut property_data)?); - } - PropertyType::SubscriptionIdentifier => { - properties.subscription_identifier.push(read_variable_integer(&mut property_data)?.0); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::ContentType => { - if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); - } - properties.content_type = Some(Box::::read(&mut property_data)?); - } - t => return Err(DeserializeError::UnexpectedProperty(t, PacketType::Publish)), - } - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for PublishProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(payload_format_indicator) = self.payload_format_indicator { - buf.put_u8(PropertyType::PayloadFormatIndicator.to_u8()); - buf.put_u8(payload_format_indicator); - } - if let Some(message_expiry_interval) = self.message_expiry_interval { - buf.put_u8(PropertyType::MessageExpiryInterval.to_u8()); - buf.put_u32(message_expiry_interval); - } - if let Some(topic_alias) = self.topic_alias { - buf.put_u8(PropertyType::TopicAlias.to_u8()); - buf.put_u16(topic_alias); - } - if let Some(response_topic) = &self.response_topic { - buf.put_u8(PropertyType::ResponseTopic.to_u8()); - response_topic.as_ref().write(buf)?; - } - if let Some(correlation_data) = &self.correlation_data { - buf.put_u8(PropertyType::CorrelationData.to_u8()); - correlation_data.write(buf)?; - } - for sub_id in &self.subscription_identifier { - buf.put_u8(PropertyType::SubscriptionIdentifier.to_u8()); - write_variable_integer(buf, *sub_id)?; - } - for (key, val) in &self.user_properties { - buf.put_u8(PropertyType::UserProperty.to_u8()); - key.write(buf)?; - val.write(buf)?; - } - if let Some(content_type) = &self.content_type { - buf.put_u8(PropertyType::ContentType.to_u8()); - content_type.write(buf)?; - } - - Ok(()) - } -} - -impl WireLength for PublishProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - - if self.payload_format_indicator.is_some() { - len += 2; - } - if self.message_expiry_interval.is_some() { - len += 5; - } - if self.topic_alias.is_some() { - len += 3; - } - if let Some(response_topic) = &self.response_topic { - len += 1 + response_topic.wire_len(); - } - if let Some(correlation_data) = &self.correlation_data { - len += 1 + correlation_data.wire_len(); - } - for sub_id in &self.subscription_identifier { - len += 1 + variable_integer_len(*sub_id); - } - for (key, val) in &self.user_properties { - len += 1 + key.wire_len() + val.wire_len(); - } - if let Some(content_type) = &self.content_type { - len += 1 + content_type.wire_len(); - } - - len - } -} - -#[cfg(test)] -mod tests { - use bytes::{BufMut, BytesMut}; - - use crate::packets::{ - mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}, - write_variable_integer, - }; - - use super::Publish; - - #[test] - fn test_read_write_properties() { - let first_byte = 0b0011_0100; - - let mut properties = [1, 0, 2].to_vec(); - properties.extend(4_294_967_295u32.to_be_bytes()); - properties.push(35); - properties.extend(3456u16.to_be_bytes()); - properties.push(8); - let resp_topic = "hellogoodbye"; - properties.extend((resp_topic.len() as u16).to_be_bytes()); - properties.extend(resp_topic.as_bytes()); - - let mut buf_one = BytesMut::from( - &[ - 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' - ][..], - ); - buf_one.put_u16(10); - write_variable_integer(&mut buf_one, properties.len()).unwrap(); - buf_one.extend(properties); - buf_one.extend( - [ - 0x01, // Payload - 0x02, 0xDE, 0xAD, 0xBE, - ] - .to_vec(), - ); - - let rem_len = buf_one.len(); - - let buf = BytesMut::from(&buf_one[..]); - - let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); - - let mut result_buf = BytesMut::new(); - p.write(&mut result_buf).unwrap(); - - dbg!(p.clone()); - - assert_eq!(buf_one.to_vec(), result_buf.to_vec()) - } - - #[test] - fn test_read_write() { - let first_byte = 0b0011_0000; - let buf_one = &[ - 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' - 0x00, 0x01, 0x02, // payload - 0xDE, 0xAD, 0xBE, - ]; - let rem_len = buf_one.len(); - - let buf = BytesMut::from(&buf_one[..]); - - let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); - - let mut result_buf = BytesMut::new(); - p.write(&mut result_buf).unwrap(); - - assert_eq!(buf_one.to_vec(), result_buf.to_vec()) - } -} diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs new file mode 100644 index 0000000..004dafc --- /dev/null +++ b/mqrstt/src/packets/publish/mod.rs @@ -0,0 +1,281 @@ +mod properties; +pub use properties::PublishProperties; + +use tokio::io::AsyncReadExt; + +use bytes::BufMut; + +use crate::error::PacketValidationError; +use crate::packets::error::ReadError; +use crate::util::constants::MAXIMUM_TOPIC_SIZE; + +use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; +use super::VariableInteger; +use super::{ + error::{DeserializeError, SerializeError}, + QoS, +}; + +/// The PUBLISH Packet is used to send data from either side of the connection. +/// This packet is handed to the [`crate::AsyncEventHandler`] to be handled by the user. +/// +/// The following flow is determined by the QoS level used in PUBLISH Packet. +/// QoS 0: Send and forget, no deliviery garantee. +/// QoS 1: Send and acknowledgement, uised to ensure that the packet is delivered at least once. +/// QoS 2: Send and 2-step acknowledgement, used to ensure that the packet is delivered only once. +/// The packet can be send using for example [`crate::MqttClient::publish`] or [`crate::MqttClient::publish_with_properties`]. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Publish { + /// 3.3.1.1 dup + pub dup: bool, + /// 3.3.1.2 QoS + pub qos: QoS, + /// 3.3.1.3 retain + pub retain: bool, + + /// 3.3.2.1 Topic Name + /// The Topic Name identifies the information channel to which Payload data is published. + pub topic: Box, + + /// 3.3.2.2 Packet Identifier + /// The Packet Identifier field is only present in PUBLISH packets where the QoS level is 1 or 2. Section 2.2.1 provides more information about Packet Identifiers. + pub packet_identifier: Option, + + /// 3.3.2.3 PUBLISH Properties + pub publish_properties: PublishProperties, + + /// 3.3.3 PUBLISH Payload + pub payload: Vec, +} + +impl Publish { + pub fn new, P: Into>>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: P) -> Self { + Self { + dup: false, + qos, + retain, + topic: topic.as_ref().into(), + packet_identifier, + publish_properties, + payload: payload.into(), + } + } + + pub fn payload(&self) -> &Vec { + &self.payload + } +} + +impl PacketRead for Publish { + fn read(flags: u8, _: usize, mut buf: bytes::Bytes) -> Result { + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let topic = Box::::read(&mut buf)?; + let mut packet_identifier = None; + if qos != QoS::AtMostOnce { + packet_identifier = Some(u16::read(&mut buf)?); + } + + let publish_properties = PublishProperties::read(&mut buf)?; + + Ok(Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload: buf.to_vec(), + }) + } +} + +impl PacketAsyncRead for Publish +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_bytes; + let packet_identifier = if qos == QoS::AtMostOnce { + None + } else { + total_read_bytes += 2; + Some(stream.read_u16().await?) + }; + let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + if total_read_bytes > remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); + } + let payload_len = remaining_length - total_read_bytes; + let mut payload = vec![0u8; payload_len]; + let payload_read_bytes = stream.read_exact(&mut payload).await?; + + assert_eq!(payload_read_bytes, payload_len); + + Ok(( + Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload, + }, + total_read_bytes + payload_read_bytes, + )) + } +} + +impl PacketWrite for Publish { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.topic.write(buf)?; + + if let Some(pkid) = self.packet_identifier { + buf.put_u16(pkid); + } + + self.publish_properties.write(buf)?; + + buf.extend(&self.payload); + + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Publish +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 0; + total_written_bytes += self.topic.async_write(stream).await?; + + if let Some(pkid) = self.packet_identifier { + stream.write_u16(pkid).await?; + total_written_bytes += 2; + } + total_written_bytes += self.publish_properties.async_write(stream).await?; + + stream.write_all(&self.payload).await?; + total_written_bytes += self.payload.len(); + + Ok(total_written_bytes) + } + } +} + +impl WireLength for Publish { + fn wire_len(&self) -> usize { + let mut len = self.topic.wire_len(); + if self.packet_identifier.is_some() { + len += 2; + } + + let properties_len = self.publish_properties.wire_len(); + + len += properties_len.variable_integer_len(); + len += properties_len; + len += self.payload.len(); + len + } +} + +impl PacketValidation for Publish { + fn validate(&self, max_packet_size: usize) -> Result<(), PacketValidationError> { + use PacketValidationError::*; + if self.wire_len() > max_packet_size { + Err(MaxPacketSize(self.wire_len())) + } else if self.topic.len() > MAXIMUM_TOPIC_SIZE { + Err(TopicSize(self.topic.len())) + } else { + Ok(()) + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{BufMut, BytesMut}; + + use crate::packets::{ + mqtt_trait::{PacketRead, PacketWrite}, + VariableInteger, + }; + + use super::Publish; + + #[test] + fn test_read_write_properties() { + let first_byte = 0b0011_0100; + + let mut properties = [1, 0, 2].to_vec(); + properties.extend(4_294_967_295u32.to_be_bytes()); + properties.push(35); + properties.extend(3456u16.to_be_bytes()); + properties.push(8); + let resp_topic = "hellogoodbye"; + properties.extend((resp_topic.len() as u16).to_be_bytes()); + properties.extend(resp_topic.as_bytes()); + + let mut buf_one = BytesMut::from( + &[ + 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' + ][..], + ); + buf_one.put_u16(10); + properties.len().write_variable_integer(&mut buf_one).unwrap(); + buf_one.extend(properties); + buf_one.extend( + [ + 0x01, // Payload + 0x02, 0xDE, 0xAD, 0xBE, + ] + .to_vec(), + ); + + let rem_len = buf_one.len(); + + let buf = buf_one.clone(); + + let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); + + let mut result_buf = BytesMut::with_capacity(1000); + p.write(&mut result_buf).unwrap(); + + assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + } + + #[test] + fn test_read_write() { + let first_byte = 0b0011_0000; + let buf_one = &[ + 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' + 0x00, 0x01, 0x02, // payload + 0xDE, 0xAD, 0xBE, + ]; + let rem_len = buf_one.len(); + + let buf = BytesMut::from(&buf_one[..]); + + let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); + + let mut result_buf = BytesMut::new(); + p.write(&mut result_buf).unwrap(); + + assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + } +} diff --git a/mqrstt/src/packets/publish/properties.rs b/mqrstt/src/packets/publish/properties.rs new file mode 100644 index 0000000..cd78081 --- /dev/null +++ b/mqrstt/src/packets/publish/properties.rs @@ -0,0 +1,131 @@ +use bytes::BufMut; + +use crate::packets::VariableInteger; + +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + PacketType, PropertyType, +}; + +crate::packets::macros::define_properties!( + /// Publish Properties + PublishProperties, + PayloadFormatIndicator, + MessageExpiryInterval, + ContentType, + ResponseTopic, + CorrelationData, + ListSubscriptionIdentifier, + TopicAlias, + UserProperty +); + +impl MqttRead for PublishProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + if len == 0 { + return Ok(Self::default()); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + let mut properties = Self::default(); + + loop { + match PropertyType::try_from(u8::read(&mut property_data)?)? { + PropertyType::PayloadFormatIndicator => { + if properties.payload_format_indicator.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); + } + properties.payload_format_indicator = Some(u8::read(&mut property_data)?); + } + PropertyType::MessageExpiryInterval => { + if properties.message_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); + } + properties.message_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::TopicAlias => { + if properties.topic_alias.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAlias)); + } + properties.topic_alias = Some(u16::read(&mut property_data)?); + } + PropertyType::ResponseTopic => { + if properties.response_topic.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); + } + properties.response_topic = Some(Box::::read(&mut property_data)?); + } + PropertyType::CorrelationData => { + if properties.correlation_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); + } + properties.correlation_data = Some(Vec::::read(&mut property_data)?); + } + PropertyType::SubscriptionIdentifier => { + properties.subscription_identifiers.push(VariableInteger::read_variable_integer(&mut property_data)?.0); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::ContentType => { + if properties.content_type.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); + } + properties.content_type = Some(Box::::read(&mut property_data)?); + } + t => return Err(DeserializeError::UnexpectedProperty(t, PacketType::Publish)), + } + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for PublishProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buf.put_u8(PropertyType::PayloadFormatIndicator.into()); + buf.put_u8(payload_format_indicator); + } + if let Some(message_expiry_interval) = self.message_expiry_interval { + buf.put_u8(PropertyType::MessageExpiryInterval.into()); + buf.put_u32(message_expiry_interval); + } + if let Some(topic_alias) = self.topic_alias { + buf.put_u8(PropertyType::TopicAlias.into()); + buf.put_u16(topic_alias); + } + if let Some(response_topic) = &self.response_topic { + buf.put_u8(PropertyType::ResponseTopic.into()); + response_topic.as_ref().write(buf)?; + } + if let Some(correlation_data) = &self.correlation_data { + buf.put_u8(PropertyType::CorrelationData.into()); + correlation_data.write(buf)?; + } + for sub_id in &self.subscription_identifiers { + buf.put_u8(PropertyType::SubscriptionIdentifier.into()); + sub_id.write_variable_integer(buf)?; + } + for (key, val) in &self.user_properties { + buf.put_u8(PropertyType::UserProperty.into()); + key.write(buf)?; + val.write(buf)?; + } + if let Some(content_type) = &self.content_type { + buf.put_u8(PropertyType::ContentType.into()); + content_type.write(buf)?; + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/pubrec.rs b/mqrstt/src/packets/pubrec/mod.rs similarity index 65% rename from mqrstt/src/packets/pubrec.rs rename to mqrstt/src/packets/pubrec/mod.rs index 15289c8..075ee04 100644 --- a/mqrstt/src/packets/pubrec.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -1,13 +1,23 @@ +mod properties; +pub use properties::PubRecProperties; + +mod reason_code; +pub use reason_code::PubRecReasonCode; + use bytes::BufMut; +use tokio::io::AsyncReadExt; + use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::PubRecReasonCode, - write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + PacketAsyncRead, VariableInteger, }; +/// The [`PubRec`] (Publish Received) packet is part of the acknowledgment flow for a [`crate::packets::Publish`] with QoS 2. +/// +/// It means that the Publish has been received, the flow will continue with the [`crate::packets::pubrel::PubRel`] +/// packet and then the [`crate::packets::pubcomp::PubComp`] packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubRec { pub packet_identifier: u16, @@ -24,7 +34,7 @@ impl PubRec { } } -impl VariableHeaderRead for PubRec { +impl PacketRead for PubRec { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -36,7 +46,7 @@ impl VariableHeaderRead for PubRec { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubRec".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -51,7 +61,42 @@ impl VariableHeaderRead for PubRec { } } -impl VariableHeaderWrite for PubRec { +impl PacketAsyncRead for PubRec +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + if remaining_length == 2 { + return Ok(( + Self { + packet_identifier, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), + }, + total_read_bytes, + )); + } + + let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; + + total_read_bytes += reason_code_read_bytes + properties_read_bytes; + + Ok(( + Self { + packet_identifier, + properties, + reason_code, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for PubRec { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -66,6 +111,28 @@ impl VariableHeaderWrite for PubRec { Ok(()) } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubRec +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 2; + self.packet_identifier.async_write(stream).await?; + + if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_written_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_written_bytes += self.reason_code.async_write(stream).await?; + } else { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} impl WireLength for PubRec { fn wire_len(&self) -> usize { @@ -74,96 +141,18 @@ impl WireLength for PubRec { } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { - 2 + 1 + self.properties.wire_len() + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len } } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubRecProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubRecProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubRecProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubRecProperties".to_string(), buf.len(), len)); - } - - let mut properties = PubRecProperties::default(); - - loop { - match PropertyType::from_u8(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRec)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for PubRecProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -impl WireLength for PubRecProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - - len - } -} - #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubrec::{PubRec, PubRecProperties}, - reason_codes::PubRecReasonCode, - write_variable_integer, PropertyType, + PropertyType, PubRecReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -226,7 +215,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -253,7 +242,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubRecProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/pubrec/properties.rs b/mqrstt/src/packets/pubrec/properties.rs new file mode 100644 index 0000000..3f3816f --- /dev/null +++ b/mqrstt/src/packets/pubrec/properties.rs @@ -0,0 +1,64 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + /// PubRec Properties + PubRecProperties, + ReasonString, + UserProperty +); + +impl MqttRead for PubRecProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubRecProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRec)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for PubRecProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/pubrec/reason_code.rs b/mqrstt/src/packets/pubrec/reason_code.rs new file mode 100644 index 0000000..6d7cf6e --- /dev/null +++ b/mqrstt/src/packets/pubrec/reason_code.rs @@ -0,0 +1,12 @@ +crate::packets::macros::reason_code!( + PubRecReasonCode, + Success, + NoMatchingSubscribers, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicNameInvalid, + PacketIdentifierInUse, + QuotaExceeded, + PayloadFormatInvalid +); diff --git a/mqrstt/src/packets/pubrel.rs b/mqrstt/src/packets/pubrel.rs deleted file mode 100644 index 79039ac..0000000 --- a/mqrstt/src/packets/pubrel.rs +++ /dev/null @@ -1,314 +0,0 @@ -use bytes::BufMut; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::PubRelReasonCode, - write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub struct PubRel { - pub packet_identifier: u16, - pub reason_code: PubRelReasonCode, - pub properties: PubRelProperties, -} - -impl PubRel { - pub fn new(packet_identifier: u16) -> Self { - Self { - packet_identifier, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - } - } -} - -impl VariableHeaderRead for PubRel { - fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { - // reason code and properties are optional if reasoncode is success and properties empty. - if remaining_length == 2 { - Ok(Self { - packet_identifier: u16::read(&mut buf)?, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }) - } else if remaining_length == 3 { - Ok(Self { - packet_identifier: u16::read(&mut buf)?, - reason_code: PubRelReasonCode::read(&mut buf)?, - properties: PubRelProperties::default(), - }) - } else { - Ok(Self { - packet_identifier: u16::read(&mut buf)?, - reason_code: PubRelReasonCode::read(&mut buf)?, - properties: PubRelProperties::read(&mut buf)?, - }) - } - } -} - -impl VariableHeaderWrite for PubRel { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - buf.put_u16(self.packet_identifier); - - if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - // Nothing here - } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - self.reason_code.write(buf)?; - } else { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for PubRel { - fn wire_len(&self) -> usize { - if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - 2 - } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - 3 - } else { - 2 + 1 + self.properties.wire_len() - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubRelProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubRelProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubRelProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubRelProperties".to_string(), buf.len(), len)); - } - - let mut properties = PubRelProperties::default(); - - loop { - match PropertyType::from_u8(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRel)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for PubRelProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -impl WireLength for PubRelProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - - len - } -} - -#[cfg(test)] -mod tests { - use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - pubrel::{PubRel, PubRelProperties}, - reason_codes::PubRelReasonCode, - write_variable_integer, PropertyType, - }; - use bytes::{BufMut, Bytes, BytesMut}; - - #[test] - fn test_wire_len() { - let mut pubrel = PubRel { - packet_identifier: 12, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }; - - let mut buf = BytesMut::new(); - - pubrel.write(&mut buf).unwrap(); - - assert_eq!(2, pubrel.wire_len()); - assert_eq!(2, buf.len()); - - pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; - buf.clear(); - pubrel.write(&mut buf).unwrap(); - - assert_eq!(3, pubrel.wire_len()); - assert_eq!(3, buf.len()); - } - - #[test] - fn test_read_short() { - let mut expected_pubrel = PubRel { - packet_identifier: 12, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }; - - let mut buf = BytesMut::new(); - - expected_pubrel.write(&mut buf).unwrap(); - - assert_eq!(2, buf.len()); - - let pubrel = PubRel::read(0, 2, buf.into()).unwrap(); - - assert_eq!(expected_pubrel, pubrel); - - let mut buf = BytesMut::new(); - expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; - expected_pubrel.write(&mut buf).unwrap(); - - assert_eq!(3, buf.len()); - - let pubrel = PubRel::read(0, 3, buf.into()).unwrap(); - assert_eq!(expected_pubrel, pubrel); - } - - #[test] - fn test_read_simple_pub_rel() { - let stream = &[ - 0x00, 0x0C, // Packet identifier = 12 - 0x00, // Reason code success - 0x00, // no properties - ]; - let buf = Bytes::from(&stream[..]); - let p_ack = PubRel::read(0, 4, buf).unwrap(); - - let expected = PubRel { - packet_identifier: 12, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }; - - assert_eq!(expected, p_ack); - } - - #[test] - fn test_read_write_pubrel_with_properties() { - let mut buf = BytesMut::new(); - - buf.put_u16(65_535u16); - buf.put_u8(0x92); - - let mut properties = BytesMut::new(); - PropertyType::ReasonString.write(&mut properties).unwrap(); - "reason string, test 1-2-3.".write(&mut properties).unwrap(); - PropertyType::UserProperty.write(&mut properties).unwrap(); - "This is the key".write(&mut properties).unwrap(); - "This is the value".write(&mut properties).unwrap(); - PropertyType::UserProperty.write(&mut properties).unwrap(); - "Another thingy".write(&mut properties).unwrap(); - "The thingy".write(&mut properties).unwrap(); - - write_variable_integer(&mut buf, properties.len()).unwrap(); - - buf.extend(properties); - - // flags can be 0 because not used. - // remaining_length must be at least 4 - let p_ack = PubRel::read(0, buf.len(), buf.clone().into()).unwrap(); - - let mut result = BytesMut::new(); - p_ack.write(&mut result).unwrap(); - - assert_eq!(buf.to_vec(), result.to_vec()); - } - - #[test] - fn test_properties() { - let mut properties_data = BytesMut::new(); - PropertyType::ReasonString.write(&mut properties_data).unwrap(); - "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); - PropertyType::UserProperty.write(&mut properties_data).unwrap(); - "This is the key".write(&mut properties_data).unwrap(); - "This is the value".write(&mut properties_data).unwrap(); - PropertyType::UserProperty.write(&mut properties_data).unwrap(); - "Another thingy".write(&mut properties_data).unwrap(); - "The thingy".write(&mut properties_data).unwrap(); - - let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); - buf.extend(properties_data); - - let properties = PubRelProperties::read(&mut buf.clone().into()).unwrap(); - let mut result = BytesMut::new(); - properties.write(&mut result).unwrap(); - - assert_eq!(buf.to_vec(), result.to_vec()); - } - - #[test] - fn no_reason_code_or_props() { - let mut buf = BytesMut::new(); - - buf.put_u16(65_535u16); - let p_ack = PubRel::read(0, buf.len(), buf.clone().into()).unwrap(); - - let mut result = BytesMut::new(); - p_ack.write(&mut result).unwrap(); - - let expected = PubRel { - packet_identifier: 65535, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }; - let mut result = BytesMut::new(); - expected.write(&mut result).unwrap(); - - assert_eq!(expected, p_ack); - assert_eq!(buf.to_vec(), result.to_vec()); - } -} diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs new file mode 100644 index 0000000..c714739 --- /dev/null +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -0,0 +1,434 @@ +mod reason_code; +pub use reason_code::PubRelReasonCode; + +mod properties; +pub use properties::PubRelProperties; + +use bytes::BufMut; +use tokio::io::AsyncReadExt; + +use super::{ + error::{DeserializeError, ReadError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, +}; + +/// The [`PubRel`] (Publish Release) packet acknowledges the reception of a [`crate::packets::PubRec`] Packet. +/// +/// This user does not need to send this message, it is handled internally by the client. +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct PubRel { + pub packet_identifier: u16, + pub reason_code: PubRelReasonCode, + pub properties: PubRelProperties, +} + +impl PubRel { + pub fn new(packet_identifier: u16) -> Self { + Self { + packet_identifier, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + } + } +} + +impl PacketRead for PubRel { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { + // reason code and properties are optional if reasoncode is success and properties empty. + if remaining_length == 2 { + Ok(Self { + packet_identifier: u16::read(&mut buf)?, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }) + } else if remaining_length == 3 { + Ok(Self { + packet_identifier: u16::read(&mut buf)?, + reason_code: PubRelReasonCode::read(&mut buf)?, + properties: PubRelProperties::default(), + }) + } else { + Ok(Self { + packet_identifier: u16::read(&mut buf)?, + reason_code: PubRelReasonCode::read(&mut buf)?, + properties: PubRelProperties::read(&mut buf)?, + }) + } + } +} + +impl PacketAsyncRead for PubRel +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + let res = if remaining_length == 2 { + Self { + packet_identifier, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + } + } else { + let (reason_code, read_bytes) = PubRelReasonCode::async_read(stream).await?; + total_read_bytes += read_bytes; + if remaining_length == 3 { + Self { + packet_identifier, + reason_code, + properties: PubRelProperties::default(), + } + } else { + let (properties, read_bytes) = PubRelProperties::async_read(stream).await?; + total_read_bytes += read_bytes; + Self { + packet_identifier, + reason_code, + properties, + } + } + }; + Ok((res, total_read_bytes)) + } +} + +impl PacketWrite for PubRel { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + buf.put_u16(self.packet_identifier); + + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + // Nothing here + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + self.reason_code.write(buf)?; + } else { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + } + Ok(()) + } +} +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubRel +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 2; + self.packet_identifier.async_write(stream).await?; + + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_written_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_written_bytes += self.reason_code.async_write(stream).await?; + } else { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + +impl WireLength for PubRel { + fn wire_len(&self) -> usize { + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + 2 + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + 3 + } else { + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len + } + } +} + +#[cfg(test)] +mod tests { + use crate::packets::{ + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + pubrel::{PubRel, PubRelProperties}, + PropertyType, PubRelReasonCode, VariableInteger, + }; + use bytes::{BufMut, Bytes, BytesMut}; + + #[test] + fn test_wire_len() { + let mut pubrel = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + let mut buf = BytesMut::new(); + + pubrel.write(&mut buf).unwrap(); + + assert_eq!(2, pubrel.wire_len()); + assert_eq!(2, buf.len()); + + pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; + buf.clear(); + pubrel.write(&mut buf).unwrap(); + + assert_eq!(3, pubrel.wire_len()); + assert_eq!(3, buf.len()); + } + + #[test] + fn test_wire_len2() { + let mut buf = BytesMut::new(); + + let prop = PubRelProperties { + reason_string: Some("reason string, test 1-2-3.".into()), // 26 + 1 + 2 + user_properties: vec![ + ("This is the key".into(), "This is the value".into()), // 32 + 1 + 2 + 2 + ("Another thingy".into(), "The thingy".into()), // 24 + 1 + 2 + 2 + ], + }; + + let len = prop.wire_len(); + // determine length of variable integer + let len_of_wire_len = len.write_variable_integer(&mut buf).unwrap(); + // clear buffer before writing actual properties + buf.clear(); + prop.write(&mut buf).unwrap(); + + assert_eq!(len + len_of_wire_len, buf.len()); + } + + #[test] + fn test_read_short() { + let mut expected_pubrel = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + let mut buf = BytesMut::new(); + + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(2, buf.len()); + + let pubrel = PubRel::read(0, 2, buf.into()).unwrap(); + + assert_eq!(expected_pubrel, pubrel); + + let mut buf = BytesMut::new(); + expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(3, buf.len()); + + let pubrel = PubRel::read(0, 3, buf.into()).unwrap(); + assert_eq!(expected_pubrel, pubrel); + } + + #[tokio::test] + async fn test_async_read_short() { + let mut expected_pubrel = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + let mut buf = BytesMut::new(); + + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(2, buf.len()); + let mut stream: &[u8] = &*buf; + + let (pubrel, read_bytes) = PubRel::async_read(0, 2, &mut stream).await.unwrap(); + + assert_eq!(expected_pubrel, pubrel); + assert_eq!(read_bytes, 2); + + let mut buf = BytesMut::new(); + expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(3, buf.len()); + let mut stream: &[u8] = &*buf; + + let (pubrel, read_bytes) = PubRel::async_read(0, 3, &mut stream).await.unwrap(); + assert_eq!(read_bytes, 3); + assert_eq!(expected_pubrel, pubrel); + } + + #[test] + fn test_read_simple_pub_rel() { + let stream = &[ + 0x00, 0x0C, // Packet identifier = 12 + 0x00, // Reason code success + 0x00, // no properties + ]; + let buf = Bytes::from(&stream[..]); + let p_ack = PubRel::read(0, 4, buf).unwrap(); + + let expected = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + assert_eq!(expected, p_ack); + } + #[tokio::test] + async fn test_async_read_simple_pub_rel() { + let stream = &[ + 0x00, 0x0C, // Packet identifier = 12 + 0x00, // Reason code success + 0x00, // no properties + ]; + + let mut stream = stream.as_ref(); + + let (p_ack, read_bytes) = PubRel::async_read(0, 4, &mut stream).await.unwrap(); + + let expected = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + assert_eq!(expected, p_ack); + assert_eq!(read_bytes, 4); + } + + #[test] + fn test_read_write_pubrel_with_properties() { + let mut buf = BytesMut::new(); + + buf.put_u16(65_535u16); + buf.put_u8(0x92); + + let mut properties = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties).unwrap(); + "reason string, test 1-2-3.".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "This is the key".write(&mut properties).unwrap(); + "This is the value".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "Another thingy".write(&mut properties).unwrap(); + "The thingy".write(&mut properties).unwrap(); + + properties.len().write_variable_integer(&mut buf).unwrap(); + + buf.extend(properties); + + // flags can be 0 because not used. + // remaining_length must be at least 4 + let p_ack = PubRel::read(0, buf.len(), buf.clone().into()).unwrap(); + + let mut result = BytesMut::new(); + p_ack.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } + + #[tokio::test] + async fn test_async_read_write_pubrel_with_properties() { + let mut buf = BytesMut::new(); + + buf.put_u16(65_535u16); + buf.put_u8(0x92); + + let mut properties = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties).unwrap(); + "reason string, test 1-2-3.".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "This is the key".write(&mut properties).unwrap(); + "This is the value".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "Another thingy".write(&mut properties).unwrap(); + "The thingy".write(&mut properties).unwrap(); + + properties.len().write_variable_integer(&mut buf).unwrap(); + + buf.extend(properties); + + let mut stream = &*buf; + // flags can be 0 because not used. + // remaining_length must be at least 4 + let (p_ack, _) = PubRel::async_read(0, buf.len(), &mut stream).await.unwrap(); + + let mut result = BytesMut::new(); + p_ack.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } + + #[test] + fn test_properties() { + let mut properties_data = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "This is the key".write(&mut properties_data).unwrap(); + "This is the value".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "Another thingy".write(&mut properties_data).unwrap(); + "The thingy".write(&mut properties_data).unwrap(); + + let mut buf = BytesMut::new(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); + buf.extend(properties_data); + + let properties = PubRelProperties::read(&mut buf.clone().into()).unwrap(); + let mut result = BytesMut::new(); + properties.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } + + #[tokio::test] + async fn test_async_read_properties() { + let mut properties_data = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "This is the key".write(&mut properties_data).unwrap(); + "This is the value".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "Another thingy".write(&mut properties_data).unwrap(); + "The thingy".write(&mut properties_data).unwrap(); + + let mut buf = BytesMut::new(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); + buf.extend(properties_data); + + let (properties, read_bytes) = PubRelProperties::async_read(&mut &*buf).await.unwrap(); + let mut result = BytesMut::new(); + properties.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + assert_eq!(buf.len(), read_bytes); + } + + #[test] + fn no_reason_code_or_props() { + let mut buf = BytesMut::new(); + + buf.put_u16(65_535u16); + let p_ack = PubRel::read(0, buf.len(), buf.clone().into()).unwrap(); + + let mut result = BytesMut::new(); + p_ack.write(&mut result).unwrap(); + + let expected = PubRel { + packet_identifier: 65535, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + let mut result = BytesMut::new(); + expected.write(&mut result).unwrap(); + + assert_eq!(expected, p_ack); + assert_eq!(buf.to_vec(), result.to_vec()); + } +} diff --git a/mqrstt/src/packets/pubrel/properties.rs b/mqrstt/src/packets/pubrel/properties.rs new file mode 100644 index 0000000..09d36d3 --- /dev/null +++ b/mqrstt/src/packets/pubrel/properties.rs @@ -0,0 +1,70 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + /// PubRel Properties + PubRelProperties, + ReasonString, + UserProperty +); + +impl PubRelProperties { + pub fn is_empty(&self) -> bool { + self.reason_string.is_none() && self.user_properties.is_empty() + } +} + +impl MqttRead for PubRelProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubRelProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRel)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for PubRelProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} diff --git a/mqrstt/src/packets/pubrel/reason_code.rs b/mqrstt/src/packets/pubrel/reason_code.rs new file mode 100644 index 0000000..24c44f6 --- /dev/null +++ b/mqrstt/src/packets/pubrel/reason_code.rs @@ -0,0 +1 @@ +crate::packets::macros::reason_code!(PubRelReasonCode, Success, PacketIdentifierNotFound); diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs deleted file mode 100644 index a940562..0000000 --- a/mqrstt/src/packets/reason_codes.rs +++ /dev/null @@ -1,532 +0,0 @@ -use std::default; - -use bytes::{Buf, BufMut}; - -use super::error::DeserializeError; -use super::mqtt_traits::{MqttRead, MqttWrite}; - -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ConnAckReasonCode { - #[default] - Success, - - UnspecifiedError, - MalformedPacket, - ProtocolError, - ImplementationSpecificError, - UnsupportedProtocolVersion, - ClientIdentifierNotValid, - BadUsernameOrPassword, - NotAuthorized, - ServerUnavailable, - ServerBusy, - Banned, - BadAuthenticationMethod, - TopicNameInvalid, - PacketTooLarge, - QuotaExceeded, - PayloadFormatInvalid, - RetainNotSupported, - QosNotSupported, - UseAnotherServer, - ServerMoved, - ConnectionRateExceeded, -} - -impl MqttRead for ConnAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConAckReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(ConnAckReasonCode::Success), - 0x80 => Ok(ConnAckReasonCode::UnspecifiedError), - 0x81 => Ok(ConnAckReasonCode::MalformedPacket), - 0x82 => Ok(ConnAckReasonCode::ProtocolError), - 0x83 => Ok(ConnAckReasonCode::ImplementationSpecificError), - 0x84 => Ok(ConnAckReasonCode::UnsupportedProtocolVersion), - 0x85 => Ok(ConnAckReasonCode::ClientIdentifierNotValid), - 0x86 => Ok(ConnAckReasonCode::BadUsernameOrPassword), - 0x87 => Ok(ConnAckReasonCode::NotAuthorized), - 0x88 => Ok(ConnAckReasonCode::ServerUnavailable), - 0x89 => Ok(ConnAckReasonCode::ServerBusy), - 0x8A => Ok(ConnAckReasonCode::Banned), - 0x8C => Ok(ConnAckReasonCode::BadAuthenticationMethod), - 0x90 => Ok(ConnAckReasonCode::TopicNameInvalid), - 0x95 => Ok(ConnAckReasonCode::PacketTooLarge), - 0x97 => Ok(ConnAckReasonCode::QuotaExceeded), - 0x99 => Ok(ConnAckReasonCode::PayloadFormatInvalid), - 0x9A => Ok(ConnAckReasonCode::RetainNotSupported), - 0x9B => Ok(ConnAckReasonCode::QosNotSupported), - 0x9C => Ok(ConnAckReasonCode::UseAnotherServer), - 0x9D => Ok(ConnAckReasonCode::ServerMoved), - 0x9F => Ok(ConnAckReasonCode::ConnectionRateExceeded), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for ConnAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - ConnAckReasonCode::Success => 0x00, - ConnAckReasonCode::UnspecifiedError => 0x80, - ConnAckReasonCode::MalformedPacket => 0x81, - ConnAckReasonCode::ProtocolError => 0x82, - ConnAckReasonCode::ImplementationSpecificError => 0x83, - ConnAckReasonCode::UnsupportedProtocolVersion => 0x84, - ConnAckReasonCode::ClientIdentifierNotValid => 0x85, - ConnAckReasonCode::BadUsernameOrPassword => 0x86, - ConnAckReasonCode::NotAuthorized => 0x87, - ConnAckReasonCode::ServerUnavailable => 0x88, - ConnAckReasonCode::ServerBusy => 0x89, - ConnAckReasonCode::Banned => 0x8A, - ConnAckReasonCode::BadAuthenticationMethod => 0x8C, - ConnAckReasonCode::TopicNameInvalid => 0x90, - ConnAckReasonCode::PacketTooLarge => 0x95, - ConnAckReasonCode::QuotaExceeded => 0x97, - ConnAckReasonCode::PayloadFormatInvalid => 0x99, - ConnAckReasonCode::RetainNotSupported => 0x9A, - ConnAckReasonCode::QosNotSupported => 0x9B, - ConnAckReasonCode::UseAnotherServer => 0x9C, - ConnAckReasonCode::ServerMoved => 0x9D, - ConnAckReasonCode::ConnectionRateExceeded => 0x9F, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum AuthReasonCode { - Success, - ContinueAuthentication, - ReAuthenticate, -} - -impl MqttRead for AuthReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("AuthReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(AuthReasonCode::Success), - 0x18 => Ok(AuthReasonCode::ContinueAuthentication), - 0x19 => Ok(AuthReasonCode::ReAuthenticate), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for AuthReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - AuthReasonCode::Success => 0x00, - AuthReasonCode::ContinueAuthentication => 0x18, - AuthReasonCode::ReAuthenticate => 0x19, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum DisconnectReasonCode { - #[default] - NormalDisconnection, - DisconnectWithWillMessage, - UnspecifiedError, - MalformedPacket, - ProtocolError, - ImplementationSpecificError, - NotAuthorized, - ServerBusy, - ServerShuttingDown, - KeepAliveTimeout, - SessionTakenOver, - TopicFilterInvalid, - TopicNameInvalid, - ReceiveMaximumExceeded, - TopicAliasInvalid, - PacketTooLarge, - MessageRateTooHigh, - QuotaExceeded, - AdministrativeAction, - PayloadFormatInvalid, - RetainNotSupported, - QosNotSupported, - UseAnotherServer, - ServerMoved, - SharedSubscriptionsNotSupported, - ConnectionRateExceeded, - MaximumConnectTime, - SubscriptionIdentifiersNotSupported, - WildcardSubscriptionsNotSupported, -} - -impl MqttRead for DisconnectReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("DisconnectReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), - 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), - 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), - 0x81 => Ok(DisconnectReasonCode::MalformedPacket), - 0x82 => Ok(DisconnectReasonCode::ProtocolError), - 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), - 0x87 => Ok(DisconnectReasonCode::NotAuthorized), - 0x89 => Ok(DisconnectReasonCode::ServerBusy), - 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), - 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), - 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), - 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), - 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), - 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), - 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), - 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), - 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), - 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), - 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), - 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), - 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), - 0x9B => Ok(DisconnectReasonCode::QosNotSupported), - 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), - 0x9D => Ok(DisconnectReasonCode::ServerMoved), - 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), - 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), - 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), - 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), - 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for DisconnectReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - DisconnectReasonCode::NormalDisconnection => 0x00, - DisconnectReasonCode::DisconnectWithWillMessage => 0x04, - DisconnectReasonCode::UnspecifiedError => 0x80, - DisconnectReasonCode::MalformedPacket => 0x81, - DisconnectReasonCode::ProtocolError => 0x82, - DisconnectReasonCode::ImplementationSpecificError => 0x83, - DisconnectReasonCode::NotAuthorized => 0x87, - DisconnectReasonCode::ServerBusy => 0x89, - DisconnectReasonCode::ServerShuttingDown => 0x8B, - DisconnectReasonCode::KeepAliveTimeout => 0x8D, - DisconnectReasonCode::SessionTakenOver => 0x8E, - DisconnectReasonCode::TopicFilterInvalid => 0x8F, - DisconnectReasonCode::TopicNameInvalid => 0x90, - DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, - DisconnectReasonCode::TopicAliasInvalid => 0x94, - DisconnectReasonCode::PacketTooLarge => 0x95, - DisconnectReasonCode::MessageRateTooHigh => 0x96, - DisconnectReasonCode::QuotaExceeded => 0x97, - DisconnectReasonCode::AdministrativeAction => 0x98, - DisconnectReasonCode::PayloadFormatInvalid => 0x99, - DisconnectReasonCode::RetainNotSupported => 0x9A, - DisconnectReasonCode::QosNotSupported => 0x9B, - DisconnectReasonCode::UseAnotherServer => 0x9C, - DisconnectReasonCode::ServerMoved => 0x9D, - DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, - DisconnectReasonCode::ConnectionRateExceeded => 0x9F, - DisconnectReasonCode::MaximumConnectTime => 0xA0, - DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, - DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubAckReasonCode { - Success, - NoMatchingSubscribers, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicNameInvalid, - PacketIdentifierInUse, - QuotaExceeded, - PayloadFormatInvalid, -} - -impl MqttRead for PubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubAckReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubAckReasonCode::Success), - 0x10 => Ok(PubAckReasonCode::NoMatchingSubscribers), - 0x80 => Ok(PubAckReasonCode::UnspecifiedError), - 0x83 => Ok(PubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(PubAckReasonCode::NotAuthorized), - 0x90 => Ok(PubAckReasonCode::TopicNameInvalid), - 0x91 => Ok(PubAckReasonCode::PacketIdentifierInUse), - 0x97 => Ok(PubAckReasonCode::QuotaExceeded), - 0x99 => Ok(PubAckReasonCode::PayloadFormatInvalid), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubAckReasonCode::Success => 0x00, - PubAckReasonCode::NoMatchingSubscribers => 0x10, - PubAckReasonCode::UnspecifiedError => 0x80, - PubAckReasonCode::ImplementationSpecificError => 0x83, - PubAckReasonCode::NotAuthorized => 0x87, - PubAckReasonCode::TopicNameInvalid => 0x90, - PubAckReasonCode::PacketIdentifierInUse => 0x91, - PubAckReasonCode::QuotaExceeded => 0x97, - PubAckReasonCode::PayloadFormatInvalid => 0x99, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubCompReasonCode { - Success, - PacketIdentifierNotFound, -} - -impl MqttRead for PubCompReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubCompReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubCompReasonCode::Success), - 0x92 => Ok(PubCompReasonCode::PacketIdentifierNotFound), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubCompReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubCompReasonCode::Success => 0x00, - PubCompReasonCode::PacketIdentifierNotFound => 0x92, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubRecReasonCode { - Success, - NoMatchingSubscribers, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicNameInvalid, - PacketIdentifierInUse, - QuotaExceeded, - PayloadFormatInvalid, -} - -impl MqttRead for PubRecReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubRecReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubRecReasonCode::Success), - 0x10 => Ok(PubRecReasonCode::NoMatchingSubscribers), - 0x80 => Ok(PubRecReasonCode::UnspecifiedError), - 0x83 => Ok(PubRecReasonCode::ImplementationSpecificError), - 0x87 => Ok(PubRecReasonCode::NotAuthorized), - 0x90 => Ok(PubRecReasonCode::TopicNameInvalid), - 0x91 => Ok(PubRecReasonCode::PacketIdentifierInUse), - 0x97 => Ok(PubRecReasonCode::QuotaExceeded), - 0x99 => Ok(PubRecReasonCode::PayloadFormatInvalid), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubRecReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubRecReasonCode::Success => 0x00, - PubRecReasonCode::NoMatchingSubscribers => 0x10, - PubRecReasonCode::UnspecifiedError => 0x80, - PubRecReasonCode::ImplementationSpecificError => 0x83, - PubRecReasonCode::NotAuthorized => 0x87, - PubRecReasonCode::TopicNameInvalid => 0x90, - PubRecReasonCode::PacketIdentifierInUse => 0x91, - PubRecReasonCode::QuotaExceeded => 0x97, - PubRecReasonCode::PayloadFormatInvalid => 0x99, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubRelReasonCode { - Success, - PacketIdentifierNotFound, -} - -impl MqttRead for PubRelReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubRelReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubRelReasonCode::Success), - 0x92 => Ok(PubRelReasonCode::PacketIdentifierNotFound), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubRelReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubRelReasonCode::Success => 0x00, - PubRelReasonCode::PacketIdentifierNotFound => 0x92, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum SubAckReasonCode { - GrantedQoS0, - GrantedQoS1, - GrantedQoS2, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicFilterInvalid, - PacketIdentifierInUse, - QuotaExceeded, - SharedSubscriptionsNotSupported, - SubscriptionIdentifiersNotSupported, - WildcardSubscriptionsNotSupported, -} - -impl MqttRead for SubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("SubAckReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(SubAckReasonCode::GrantedQoS0), - 0x01 => Ok(SubAckReasonCode::GrantedQoS1), - 0x02 => Ok(SubAckReasonCode::GrantedQoS2), - 0x80 => Ok(SubAckReasonCode::UnspecifiedError), - 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(SubAckReasonCode::NotAuthorized), - 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), - 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), - 0x97 => Ok(SubAckReasonCode::QuotaExceeded), - 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), - 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), - 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for SubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - SubAckReasonCode::GrantedQoS0 => 0x00, - SubAckReasonCode::GrantedQoS1 => 0x01, - SubAckReasonCode::GrantedQoS2 => 0x02, - SubAckReasonCode::UnspecifiedError => 0x80, - SubAckReasonCode::ImplementationSpecificError => 0x83, - SubAckReasonCode::NotAuthorized => 0x87, - SubAckReasonCode::TopicFilterInvalid => 0x8F, - SubAckReasonCode::PacketIdentifierInUse => 0x91, - SubAckReasonCode::QuotaExceeded => 0x97, - SubAckReasonCode::SharedSubscriptionsNotSupported => 0x9E, - SubAckReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, - SubAckReasonCode::WildcardSubscriptionsNotSupported => 0xA2, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum UnsubAckReasonCode { - Success, - NoSubscriptionExisted, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicFilterInvalid, - PacketIdentifierInUse, -} - -impl MqttRead for UnsubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("UnsubAckReasonCode".to_string(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(UnsubAckReasonCode::Success), - 0x11 => Ok(UnsubAckReasonCode::NoSubscriptionExisted), - 0x80 => Ok(UnsubAckReasonCode::UnspecifiedError), - 0x83 => Ok(UnsubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(UnsubAckReasonCode::NotAuthorized), - 0x8F => Ok(UnsubAckReasonCode::TopicFilterInvalid), - 0x91 => Ok(UnsubAckReasonCode::PacketIdentifierInUse), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for UnsubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - UnsubAckReasonCode::Success => 0x00, - UnsubAckReasonCode::NoSubscriptionExisted => 0x11, - UnsubAckReasonCode::UnspecifiedError => 0x80, - UnsubAckReasonCode::ImplementationSpecificError => 0x83, - UnsubAckReasonCode::NotAuthorized => 0x87, - UnsubAckReasonCode::TopicFilterInvalid => 0x8F, - UnsubAckReasonCode::PacketIdentifierInUse => 0x91, - }; - - buf.put_u8(val); - Ok(()) - } -} diff --git a/mqrstt/src/packets/suback.rs b/mqrstt/src/packets/suback.rs deleted file mode 100644 index 3f8caa2..0000000 --- a/mqrstt/src/packets/suback.rs +++ /dev/null @@ -1,160 +0,0 @@ -use bytes::BufMut; - -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::SubAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -///3.9 SUBACK – Subscribe acknowledgement -/// A SUBACK packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE packet. -/// A SUBACK packet contains a list of Reason Codes, that specify the maximum QoS level that was granted or the error which was found for each Subscription that was requested by the SUBSCRIBE. -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubAck { - pub packet_identifier: u16, - pub properties: SubAckProperties, - pub reason_codes: Vec, -} - -impl VariableHeaderRead for SubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let packet_identifier = u16::read(&mut buf)?; - let properties = SubAckProperties::read(&mut buf)?; - let mut reason_codes = vec![]; - loop { - let reason_code = SubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - - if buf.is_empty() { - break; - } - } - - Ok(Self { - packet_identifier, - properties, - reason_codes, - }) - } -} - -impl VariableHeaderWrite for SubAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.packet_identifier); - - self.properties.write(buf)?; - for reason_code in &self.reason_codes { - reason_code.write(buf)?; - } - - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubAckProperties { - /// 3.8.2.1.2 Subscription Identifier - /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. - pub subscription_id: Option, - - /// 3.8.2.1.3 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for SubAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = SubAckProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("SubAckProperties".to_string(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::SubscriptionIdentifier => { - if properties.subscription_id.is_none() { - let (subscription_id, _) = read_variable_integer(&mut properties_data)?; - - properties.subscription_id = Some(subscription_id); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), - } - - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for SubAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - if let Some(sub_id) = self.subscription_id { - PropertyType::SubscriptionIdentifier.write(buf)?; - write_variable_integer(buf, sub_id)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for SubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(sub_id) = self.subscription_id { - len += 1 + variable_integer_len(sub_id); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - -#[cfg(test)] -mod test { - use bytes::BytesMut; - - use super::SubAck; - use crate::packets::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}; - - #[test] - fn read_write_suback() { - let buf = vec![ - 0x00, 0x0F, // variable header. pkid = 15 - 0x00, // Property length 0 - 0x01, // Payload reason code codes Granted QoS 1, - 0x80, // Payload Unspecified error - ]; - - let data = BytesMut::from(&buf[..]); - let sub_ack = SubAck::read(0, 0, data.clone().into()).unwrap(); - - let mut result = BytesMut::new(); - sub_ack.write(&mut result).unwrap(); - - assert_eq!(data.to_vec(), result.to_vec()); - } -} diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs new file mode 100644 index 0000000..df8b1a2 --- /dev/null +++ b/mqrstt/src/packets/suback/mod.rs @@ -0,0 +1,148 @@ +mod properties; + +pub use properties::SubAckProperties; + +mod reason_code; +pub use reason_code::SubAckReasonCode; + +use super::{ + error::SerializeError, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + VariableInteger, WireLength, +}; +use bytes::BufMut; +use tokio::io::AsyncReadExt; + +/// SubAck packet is sent by the server in response to a [`crate::packets::Subscribe`] packet. +/// +/// A SUBACK packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE packet. +/// A SUBACK packet contains a list of Reason Codes, that specify the maximum QoS level that was granted or the error which was found for each Subscription that was requested by the SUBSCRIBE. +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct SubAck { + pub packet_identifier: u16, + pub properties: SubAckProperties, + pub reason_codes: Vec, +} + +impl PacketRead for SubAck { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { + let packet_identifier = u16::read(&mut buf)?; + let properties = SubAckProperties::read(&mut buf)?; + + let mut reason_codes = vec![]; + + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read >= remaining_length { + break; + } + + let reason_code = SubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; + } + + Ok(Self { + packet_identifier, + properties, + reason_codes, + }) + } +} + +impl PacketAsyncRead for SubAck +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; + let mut reason_codes = vec![]; + loop { + if remaining_length == total_read_bytes { + break; + } + + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); + } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for SubAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.packet_identifier); + + self.properties.write(buf)?; + for reason_code in &self.reason_codes { + reason_code.write(buf)?; + } + + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for SubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; + } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) + } +} + +impl WireLength for SubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::SubAck; + use crate::packets::mqtt_trait::{PacketRead, PacketWrite}; + + #[test] + fn read_write_suback() { + let buf = vec![ + 0x00, 0x0F, // variable header. pkid = 15 + 0x00, // Property length 0 + 0x01, // Payload reason code codes Granted QoS 1, + 0x80, // Payload Unspecified error + ]; + + let data = BytesMut::from(&buf[..]); + let sub_ack = SubAck::read(0, 5, data.clone().into()).unwrap(); + + let mut result = BytesMut::new(); + sub_ack.write(&mut result).unwrap(); + + assert_eq!(data.to_vec(), result.to_vec()); + } +} diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs new file mode 100644 index 0000000..15f6997 --- /dev/null +++ b/mqrstt/src/packets/suback/properties.rs @@ -0,0 +1,69 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, +}; + +use crate::packets::primitive::VariableInteger; + +crate::packets::macros::define_properties!( + /// SubAck Properties + SubAckProperties, + SubscriptionIdentifier, + UserProperty +); + +impl MqttRead for SubAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = SubAckProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::SubscriptionIdentifier => { + if properties.subscription_identifier.is_none() { + let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; + + properties.subscription_identifier = Some(subscription_id); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for SubAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(sub_id) = self.subscription_identifier { + PropertyType::SubscriptionIdentifier.write(buf)?; + sub_id.write_variable_integer(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} diff --git a/mqrstt/src/packets/suback/reason_code.rs b/mqrstt/src/packets/suback/reason_code.rs new file mode 100644 index 0000000..0b89706 --- /dev/null +++ b/mqrstt/src/packets/suback/reason_code.rs @@ -0,0 +1,15 @@ +crate::packets::macros::reason_code!( + SubAckReasonCode, + GrantedQoS0, + GrantedQoS1, + GrantedQoS2, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicFilterInvalid, + PacketIdentifierInUse, + QuotaExceeded, + SharedSubscriptionsNotSupported, + SubscriptionIdentifiersNotSupported, + WildcardSubscriptionsNotSupported +); diff --git a/mqrstt/src/packets/subscribe.rs b/mqrstt/src/packets/subscribe/mod.rs similarity index 66% rename from mqrstt/src/packets/subscribe.rs rename to mqrstt/src/packets/subscribe/mod.rs index 80d994f..deb39ab 100644 --- a/mqrstt/src/packets/subscribe.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -1,12 +1,21 @@ +mod properties; + +pub use properties::SubscribeProperties; +use tokio::io::AsyncReadExt; + use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, + QoS, VariableInteger, }; use bytes::{Buf, BufMut}; +/// Used to subscribe to topic(s). +/// +/// Multiple topics can be subscribed from at once. +/// For convenience [`SubscribeTopics`] is provided. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Subscribe { pub packet_identifier: u16, @@ -24,7 +33,7 @@ impl Subscribe { } } -impl VariableHeaderRead for Subscribe { +impl PacketRead for Subscribe { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubscribeProperties::read(&mut buf)?; @@ -48,7 +57,40 @@ impl VariableHeaderRead for Subscribe { } } -impl VariableHeaderWrite for Subscribe { +impl PacketAsyncRead for Subscribe +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; + total_read_bytes += topic_read_bytes + options_read_bytes; + topics.push((topic, options)); + + if remaining_length >= total_read_bytes { + break; + } + } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for Subscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -62,11 +104,30 @@ impl VariableHeaderWrite for Subscribe { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Subscribe +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + for (topic, options) in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + total_written_bytes += options.async_write(stream).await?; + } + Ok(total_written_bytes) + } +} + impl WireLength for Subscribe { fn wire_len(&self) -> usize { let mut len = 2; let properties_len = self.properties.wire_len(); - len += properties_len + variable_integer_len(properties_len); + len += properties_len + properties_len.variable_integer_len(); for topic in &self.topics { len += topic.0.wire_len() + 1; } @@ -88,85 +149,6 @@ impl PacketValidation for Subscribe { } } -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubscribeProperties { - /// 3.8.2.1.2 Subscription Identifier - /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. - pub subscription_id: Option, - - /// 3.8.2.1.3 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for SubscribeProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = SubscribeProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("SubscribeProperties".to_string(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::SubscriptionIdentifier => { - if properties.subscription_id.is_none() { - let (subscription_id, _) = read_variable_integer(&mut properties_data)?; - - properties.subscription_id = Some(subscription_id); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Subscribe)), - } - - if properties_data.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for SubscribeProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - if let Some(sub_id) = self.subscription_id { - PropertyType::SubscriptionIdentifier.write(buf)?; - write_variable_integer(buf, sub_id)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for SubscribeProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(sub_id) = self.subscription_id { - len += 1 + variable_integer_len(sub_id); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct SubscriptionOptions { pub retain_handling: RetainHandling, @@ -189,7 +171,7 @@ impl Default for SubscriptionOptions { impl MqttRead for SubscriptionOptions { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("SubscriptionOptions".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } let byte = buf.get_u8(); @@ -210,6 +192,29 @@ impl MqttRead for SubscriptionOptions { } } +impl MqttAsyncRead for SubscriptionOptions +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; + + let retain_handling_part = (byte & 0b00110000) >> 4; + let retain_as_publish_part = (byte & 0b00001000) >> 3; + let no_local_part = (byte & 0b00000100) >> 2; + let qos_part = byte & 0b00000011; + + let options = Self { + retain_handling: RetainHandling::from_u8(retain_handling_part)?, + retain_as_publish: retain_as_publish_part != 0, + no_local: no_local_part != 0, + qos: QoS::from_u8(qos_part)?, + }; + + Ok((options, 1)) + } +} + impl MqttWrite for SubscriptionOptions { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); @@ -219,10 +224,30 @@ impl MqttWrite for SubscriptionOptions { } } +impl MqttAsyncWrite for SubscriptionOptions +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use tokio::io::AsyncWriteExt; + async move { + let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); + stream.write_u8(byte).await?; + Ok(1) + } + } +} + +/// Controls how retained messages are handled +/// +/// Used when a new subscription is established. Here are the three options for retain handling: #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RetainHandling { + /// Send Retained Messages at Subscription: This is the default behavior. When a client subscribes to a topic, the broker sends any retained messages for that topic immediately. ZERO, + /// Send Retained Messages Only for New Subscriptions: Retained messages are sent only if the subscription did not previously exist. ONE, + /// Do Not Send Retained Messages: Retained messages are not sent when the subscription is established TWO, } @@ -312,7 +337,7 @@ where macro_rules! impl_subscription { ($t:ty) => { - impl From<$t> for Subscription { + impl From<$t> for SubscribeTopics { #[inline] fn from(value: $t) -> Self { Self(vec![IntoSingleSubscription::into(value)]) @@ -321,19 +346,19 @@ macro_rules! impl_subscription { }; } -pub struct Subscription(pub Vec<(Box, SubscriptionOptions)>); +pub struct SubscribeTopics(pub Vec<(Box, SubscriptionOptions)>); // -------------------- Simple types -------------------- impl_subscription!(&str); impl_subscription!(&String); impl_subscription!(String); impl_subscription!(Box); -impl From<&(&str, QoS)> for Subscription { +impl From<&(&str, QoS)> for SubscribeTopics { fn from(value: &(&str, QoS)) -> Self { Self(vec![IntoSingleSubscription::into(value)]) } } -impl From<(T, QoS)> for Subscription +impl From<(T, QoS)> for SubscribeTopics where (T, QoS): IntoSingleSubscription, { @@ -341,7 +366,7 @@ where Self(vec![IntoSingleSubscription::into(value)]) } } -impl From<(T, SubscriptionOptions)> for Subscription +impl From<(T, SubscriptionOptions)> for SubscribeTopics where (T, SubscriptionOptions): IntoSingleSubscription, { @@ -350,38 +375,38 @@ where } } // -------------------- Arrays -------------------- -impl From<&[T; S]> for Subscription +impl From<&[T; S]> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Slices -------------------- -impl From<&[T]> for Subscription +impl From<&[T]> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Vecs -------------------- -impl From> for Subscription +impl From> for SubscribeTopics where T: IntoSingleSubscription, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.into_iter().map(IntoSingleSubscription::into).collect()) } } -impl From<&Vec> for Subscription +impl From<&Vec> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &Vec) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } @@ -390,7 +415,7 @@ mod tests { use bytes::{Bytes, BytesMut}; use crate::packets::{ - mqtt_traits::{MqttRead, VariableHeaderRead, VariableHeaderWrite}, + mqtt_trait::{MqttRead, PacketRead, PacketWrite}, Packet, }; diff --git a/mqrstt/src/packets/subscribe/properties.rs b/mqrstt/src/packets/subscribe/properties.rs new file mode 100644 index 0000000..6703e99 --- /dev/null +++ b/mqrstt/src/packets/subscribe/properties.rs @@ -0,0 +1,67 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + /// Subscribe Properties + SubscribeProperties, + SubscriptionIdentifier, + UserProperty +); + +impl MqttRead for SubscribeProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = SubscribeProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::SubscriptionIdentifier => { + if properties.subscription_identifier.is_none() { + let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; + + properties.subscription_identifier = Some(subscription_id); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Subscribe)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for SubscribeProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(sub_id) = self.subscription_identifier { + PropertyType::SubscriptionIdentifier.write(buf)?; + sub_id.write_variable_integer(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} diff --git a/mqrstt/src/packets/unsuback.rs b/mqrstt/src/packets/unsuback.rs deleted file mode 100644 index da9e447..0000000 --- a/mqrstt/src/packets/unsuback.rs +++ /dev/null @@ -1,149 +0,0 @@ -use bytes::BufMut; - -use super::error::{DeserializeError, SerializeError}; -use super::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; -use super::{read_variable_integer, reason_codes::UnsubAckReasonCode, write_variable_integer, PacketType, PropertyType}; - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubAck { - pub packet_identifier: u16, - pub properties: UnsubAckProperties, - pub reason_codes: Vec, -} - -impl VariableHeaderRead for UnsubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let packet_identifier = u16::read(&mut buf)?; - let properties = UnsubAckProperties::read(&mut buf)?; - let mut reason_codes = vec![]; - loop { - let reason_code = UnsubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - - if buf.is_empty() { - break; - } - } - - Ok(Self { - packet_identifier, - properties, - reason_codes, - }) - } -} - -impl VariableHeaderWrite for UnsubAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.packet_identifier); - self.properties.write(buf)?; - for reason_code in &self.reason_codes { - reason_code.write(buf)?; - } - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubAckProperties { - /// 3.11.2.1.2 Reason String - /// 31 (0x1F) Byte, Identifier of the Reason String. - pub reason_string: Option>, - - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for UnsubAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = UnsubAckProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("UnsubAckProperties".to_string(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::ReasonString => { - if properties.reason_string.is_none() { - properties.reason_string = Some(Box::::read(&mut properties_data)?); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), - } - - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for UnsubAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for UnsubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += 1 + reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use crate::packets::{ - mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}, - unsuback::UnsubAck, - }; - - #[test] - fn read_write_unsub_ack() { - // let entire_mqtt_packet = [0xb0, 0x04, 0x35, 0xd7, 0x00, 0x00]; - let unsub_ack = [0x35, 0xd7, 0x00, 0x00]; - - let mut bufmut = BytesMut::new(); - bufmut.extend(&unsub_ack[..]); - - let buf: Bytes = bufmut.into(); - - let s = UnsubAck::read(0xb0, 4, buf.clone()).unwrap(); - - let mut result = BytesMut::new(); - s.write(&mut result).unwrap(); - - assert_eq!(buf.to_vec(), result.to_vec()); - } -} diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs new file mode 100644 index 0000000..9823bed --- /dev/null +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -0,0 +1,150 @@ +mod properties; +pub use properties::UnsubAckProperties; + +mod reason_code; +pub use reason_code::UnsubAckReasonCode; + +use crate::packets::mqtt_trait::MqttAsyncRead; + +use bytes::BufMut; + +use tokio::io::AsyncReadExt; + +use super::error::SerializeError; +use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; +use super::{PacketAsyncRead, VariableInteger, WireLength}; + +/// UnsubAck packet is sent by the server in response to an [`crate::packets::Unsubscribe`] packet. +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct UnsubAck { + pub packet_identifier: u16, + pub properties: UnsubAckProperties, + pub reason_codes: Vec, +} + +impl PacketRead for UnsubAck { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { + let packet_identifier = u16::read(&mut buf)?; + let properties = UnsubAckProperties::read(&mut buf)?; + let mut reason_codes = vec![]; + + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read == remaining_length { + break; + } + + let reason_code = UnsubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; + } + + Ok(Self { + packet_identifier, + properties, + reason_codes, + }) + } +} + +impl PacketAsyncRead for UnsubAck +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 2; + let packet_identifier = stream.read_u16().await?; + + let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + let mut reason_codes = vec![]; + loop { + if total_read_bytes >= remaining_length { + break; + } + + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + + reason_codes.push(reason_code); + } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for UnsubAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.packet_identifier); + self.properties.write(buf)?; + for reason_code in &self.reason_codes { + reason_code.write(buf)?; + } + Ok(()) + } +} + +impl crate::packets::mqtt_trait::PacketAsyncWrite for UnsubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; + } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) + } + } +} + +impl WireLength for UnsubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::packets::{ + mqtt_trait::{PacketRead, PacketWrite}, + unsuback::UnsubAck, + }; + + #[test] + fn read_write_unsub_ack() { + // let entire_mqtt_packet = [0xb0, 0x04, 0x35, 0xd7, 0x00, 0x00]; + let unsub_ack = [0x35, 0xd7, 0x00, 0x00]; + + let mut bufmut = BytesMut::new(); + bufmut.extend(&unsub_ack[..]); + + let buf: Bytes = bufmut.into(); + + let s = UnsubAck::read(0xb0, 4, buf.clone()).unwrap(); + + let mut result = BytesMut::new(); + s.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } +} diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs new file mode 100644 index 0000000..d462c50 --- /dev/null +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -0,0 +1,63 @@ +use crate::packets::error::DeserializeError; +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; +use crate::packets::{PacketType, PropertyType, VariableInteger}; + +crate::packets::macros::define_properties!( + /// UnsubAck Properties + UnsubAckProperties, + ReasonString, + UserProperty +); + +impl MqttRead for UnsubAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = UnsubAckProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::ReasonString => { + if properties.reason_string.is_none() { + properties.reason_string = Some(Box::::read(&mut properties_data)?); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for UnsubAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} diff --git a/mqrstt/src/packets/unsuback/reason_code.rs b/mqrstt/src/packets/unsuback/reason_code.rs new file mode 100644 index 0000000..9ac2036 --- /dev/null +++ b/mqrstt/src/packets/unsuback/reason_code.rs @@ -0,0 +1,10 @@ +crate::packets::macros::reason_code!( + UnsubAckReasonCode, + Success, + NoSubscriptionExisted, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicFilterInvalid, + PacketIdentifierInUse +); diff --git a/mqrstt/src/packets/unsubscribe.rs b/mqrstt/src/packets/unsubscribe/mod.rs similarity index 67% rename from mqrstt/src/packets/unsubscribe.rs rename to mqrstt/src/packets/unsubscribe/mod.rs index 19e08f4..e316623 100644 --- a/mqrstt/src/packets/unsubscribe.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -1,13 +1,20 @@ +mod properties; +pub use properties::UnsubscribeProperties; + use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; +use crate::packets::mqtt_trait::MqttAsyncRead; + +use super::mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; +use super::VariableInteger; use bytes::BufMut; +use tokio::io::AsyncReadExt; #[derive(Debug, Clone, PartialEq, Eq)] +/// Used to unsubscribe from topic(s). +/// +/// Multiple topics can be unsubscribed from at once. +/// For convenience [`UnsubscribeTopics`] is provided. pub struct Unsubscribe { pub packet_identifier: u16, pub properties: UnsubscribeProperties, @@ -24,7 +31,7 @@ impl Unsubscribe { } } -impl VariableHeaderRead for Unsubscribe { +impl PacketRead for Unsubscribe { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubscribeProperties::read(&mut buf)?; @@ -47,7 +54,40 @@ impl VariableHeaderRead for Unsubscribe { } } -impl VariableHeaderWrite for Unsubscribe { +impl PacketAsyncRead for Unsubscribe +where + S: tokio::io::AsyncRead + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + properties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_size) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_size; + + topics.push(topic); + + if total_read_bytes >= remaining_length { + break; + } + } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) + } +} + +impl PacketWrite for Unsubscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); self.properties.write(buf)?; @@ -59,10 +99,30 @@ impl VariableHeaderWrite for Unsubscribe { Ok(()) } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Unsubscribe +where + S: tokio::io::AsyncWrite + Unpin, +{ + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + + for topic in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + } + + Ok(total_written_bytes) + } +} impl WireLength for Unsubscribe { fn wire_len(&self) -> usize { - let mut len = 2 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len(); + let mut len = 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len(); for topic in &self.topics { len += topic.wire_len(); } @@ -84,63 +144,6 @@ impl PacketValidation for Unsubscribe { } } -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubscribeProperties { - pub user_properties: Vec<(String, String)>, -} - -impl MqttRead for UnsubscribeProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = UnsubscribeProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("UnsubscribeProperties".to_string(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::UserProperty => { - properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Unsubscribe)), - } - - if properties_data.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for UnsubscribeProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for UnsubscribeProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - trait IntoUnsubscribeTopic { fn into(value: Self) -> Box; } @@ -202,7 +205,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } // -------------------- Slices -------------------- @@ -211,7 +214,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } impl From<&[&str]> for UnsubscribeTopics { @@ -229,26 +232,13 @@ where } } -// impl From<&[&T]> for UnsubscribeTopics -// where -// SingleUnsubscribeTopic: for<'any> From<&'any T>, -// { -// fn from(value: &[&T]) -> Self { -// Self( -// value -// .iter() -// .map(|val| SingleUnsubscribeTopic::from(val).0) -// .collect(), -// ) -// } -// } // -------------------- Vecs -------------------- impl From> for UnsubscribeTopics where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(&val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } @@ -257,7 +247,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } @@ -280,7 +270,7 @@ mod tests { use bytes::{Bytes, BytesMut}; - use crate::packets::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}; + use crate::packets::mqtt_trait::{PacketRead, PacketWrite}; use super::Unsubscribe; diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs new file mode 100644 index 0000000..510e96d --- /dev/null +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -0,0 +1,55 @@ +use crate::packets::VariableInteger; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, +}; + +crate::packets::macros::define_properties!( + /// Unsubscribe Properties + UnsubscribeProperties, + UserProperty +); + +impl MqttRead for UnsubscribeProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = UnsubscribeProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Unsubscribe)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for UnsubscribeProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 20ef952..3eed906 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -9,14 +9,13 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::ConnectionError; use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; -use crate::packets::{Disconnect, Packet, PacketType}; +use crate::packets::{Disconnect, DisconnectReasonCode, Packet, PacketType}; use crate::NetworkStatus; -use crate::{AsyncEventHandlerMut, StateHandler}; +use crate::{AsyncEventHandler, StateHandler}; use super::stream::Stream; -/// [`Network`] reads and writes to the network based on tokios [`::smol::io::AsyncReadExt`] [`::smol::io::AsyncWriteExt`]. +/// [`Network`] reads and writes to the network based on tokios [`::smol::io::AsyncRead`] [`::smol::io::AsyncWrite`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). @@ -39,7 +38,7 @@ pub struct Network { } impl Network { - pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { + pub(crate) fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { let state_handler = StateHandler::new(&options, apkids); Self { handler: PhantomData, @@ -62,8 +61,8 @@ impl Network { impl Network where - H: AsyncEventHandlerMut, - S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler, + S: smol::io::AsyncRead + smol::io::AsyncWrite + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { diff --git a/mqrstt/src/smol/stream.rs b/mqrstt/src/smol/stream.rs index a24c1a6..2f1cbd8 100644 --- a/mqrstt/src/smol/stream.rs +++ b/mqrstt/src/smol/stream.rs @@ -10,8 +10,7 @@ use tracing::trace; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, - reason_codes::ConnAckReasonCode, - {FixedHeader, Packet}, + ConnAckReasonCode, {FixedHeader, Packet}, }; use crate::{connect_options::ConnectOptions, error::ConnectionError}; @@ -40,7 +39,7 @@ impl Stream { self.read_buffer.advance(header_length); let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; + let read_packet = Packet::read_packet(header, buf.into())?; #[cfg(feature = "logs")] trace!("Read packet from network {}", read_packet); @@ -96,7 +95,7 @@ where let buf = self.read_buffer.split_to(header.remaining_length); - return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); + return Packet::read_packet(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); } } diff --git a/mqrstt/src/state.rs b/mqrstt/src/state.rs index 3c52951..4ece423 100644 --- a/mqrstt/src/state.rs +++ b/mqrstt/src/state.rs @@ -34,7 +34,7 @@ pub struct State { impl State { pub fn new(receive_maximum: u16, apkid: AvailablePacketIds) -> Self { - let state = Self { + Self { apkid, outgoing_sub: Mutex::new(BTreeSet::new()), @@ -45,9 +45,7 @@ impl State { outgoing_pub_order: Mutex::new(VecDeque::new()), outgoing_rel: Mutex::new(BTreeSet::new()), incoming_pub: Mutex::new(BTreeSet::new()), - }; - - state + } } pub fn make_pkid_available(&self, pkid: u16) -> Result<(), HandlerError> { diff --git a/mqrstt/src/state_handler.rs b/mqrstt/src/state_handler.rs index 03ea6fa..ecf7fd6 100644 --- a/mqrstt/src/state_handler.rs +++ b/mqrstt/src/state_handler.rs @@ -1,7 +1,6 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::HandlerError; -use crate::packets::reason_codes::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; use crate::packets::PubComp; use crate::packets::PubRec; use crate::packets::PubRel; @@ -12,6 +11,7 @@ use crate::packets::Subscribe; use crate::packets::UnsubAck; use crate::packets::Unsubscribe; use crate::packets::{ConnAck, Disconnect}; +use crate::packets::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; use crate::packets::{Packet, PacketType}; use crate::packets::{PubAck, PubAckProperties}; use crate::state::State; @@ -20,7 +20,7 @@ use crate::state::State; use tracing::{debug, error, info, warn}; /// Eventloop with all the state of a connection -pub struct StateHandler { +pub(crate) struct StateHandler { state: State, clean_start: bool, } @@ -188,6 +188,7 @@ impl StateHandler { _a => { #[cfg(test)] unreachable!("Was given unexpected packet {:?} ", _a); + #[cfg(not(test))] Ok(()) } } @@ -247,8 +248,8 @@ mod handler_tests { use crate::{ available_packet_ids::AvailablePacketIds, packets::{ - reason_codes::{PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, UnsubAckReasonCode}, - Packet, QoS, UnsubAck, UnsubAckProperties, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, + Packet, PubComp, PubCompProperties, PubCompReasonCode, PubRec, PubRecProperties, PubRecReasonCode, PubRel, PubRelProperties, PubRelReasonCode, QoS, SubAck, SubAckProperties, + SubAckReasonCode, UnsubAck, UnsubAckProperties, UnsubAckReasonCode, }, tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, ConnectOptions, StateHandler, diff --git a/mqrstt/src/tests/test_bytes.rs b/mqrstt/src/tests/test_bytes.rs index 93a35ef..83fcc41 100644 --- a/mqrstt/src/tests/test_bytes.rs +++ b/mqrstt/src/tests/test_bytes.rs @@ -2,7 +2,7 @@ use rstest::*; use bytes::BytesMut; -use crate::packets::{mqtt_traits::WireLength, Packet}; +use crate::packets::{mqtt_trait::WireLength, Packet}; fn publish_packet() -> Vec { const PUBLISH_BYTES: [u8; 79] = [ @@ -52,12 +52,12 @@ pub fn subscribe_packet() -> Vec { fn publish_packet_test(#[case] bytes: Vec) { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); assert_eq!(bytes.len(), write_buffer.len()); - let packet_from_write_buffer = Packet::read_from_buffer(&mut write_buffer).unwrap(); + let packet_from_write_buffer = Packet::read(&mut write_buffer).unwrap(); assert_eq!(packet, packet_from_write_buffer); } @@ -68,7 +68,7 @@ fn test_connect() { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); if let Packet::Connect(p) = &packet { @@ -79,7 +79,7 @@ fn test_connect() { assert_eq!(bytes.len(), write_buffer.len()); assert_eq!(bytes, write_buffer.to_vec()); - let packet_from_write_buffer = Packet::read_from_buffer(&mut write_buffer).unwrap(); + let packet_from_write_buffer = Packet::read(&mut write_buffer).unwrap(); assert_eq!(packet, packet_from_write_buffer); } @@ -93,7 +93,7 @@ fn test_connect() { fn test_equal_read_write_packet_from_bytes(#[case] bytes: Vec) { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); assert_eq!(bytes, write_buffer.to_vec()); diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 747b1d1..f23a80a 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -1,13 +1,150 @@ -use bytes::Bytes; - use rstest::*; -use crate::packets::{ - reason_codes::{DisconnectReasonCode, PubAckReasonCode}, - ConnAck, Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties, QoS, Subscribe, Subscription, Unsubscribe, -}; +use crate::packets::*; + +pub fn connack_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[ + 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, + ]; + + let expected = ConnAck { + connack_flags: ConnAckFlags { session_present: true }, + reason_code: ConnAckReasonCode::Success, + connack_properties: ConnAckProperties { + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: Some(true), + maximum_packet_size: Some(1048576), + assigned_client_identifier: None, + topic_alias_maximum: Some(65535), + reason_string: None, + user_properties: vec![], + wildcards_available: Some(true), + subscription_ids_available: Some(true), + shared_subscription_available: Some(true), + server_keep_alive: None, + response_info: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }, + }; + + (packet, Packet::ConnAck(expected)) +} + +pub fn disconnect_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xe0, 0x02, 0x8e, 0x00]; + + let expected = Disconnect { + reason_code: DisconnectReasonCode::SessionTakenOver, + properties: DisconnectProperties { + session_expiry_interval: None, + reason_string: None, + user_properties: vec![], + server_reference: None, + }, + }; + + (packet, Packet::Disconnect(expected)) +} + +pub fn ping_req_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xc0, 0x00]; + + (packet, Packet::PingReq) +} + +pub fn ping_resp_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xd0, 0x00]; + + (packet, Packet::PingResp) +} +pub fn publish_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[ + 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, 0x01, 0x09, 0x00, 0x04, + 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, + ]; + + let expected = Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: true, + topic: "test/123/test/blabla".into(), + packet_identifier: Some(13779), + publish_properties: PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: None, + topic_alias: None, + response_topic: None, + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], + user_properties: vec![], + content_type: None, + }, + payload: b"".to_vec(), + }; + + (packet, Packet::Publish(expected)) +} + +pub fn pubrel_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0x62, 0x02, 0x35, 0xd3]; -fn publish_packet_1() -> Packet { + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + (packet, Packet::PubRel(expected)) +} + +pub fn pubrel_smallest_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0x62, 0x02, 0x35, 0xd3]; + + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + (packet, Packet::PubRel(expected)) +} + +pub fn connect_case() -> Packet { + let connect = Connect { + protocol_version: ProtocolVersion::V5, + clean_start: true, + last_will: Some(LastWill::new(QoS::ExactlyOnce, true, "will/topic", b"will payload".to_vec())), + username: Some("ThisIsTheUsername".into()), + password: Some("ThisIsThePassword".into()), + keep_alive: 60, + connect_properties: ConnectProperties { + session_expiry_interval: Some(5), + receive_maximum: Some(10), + maximum_packet_size: Some(100), + topic_alias_maximum: Some(10), + user_properties: vec![("test".into(), "test".into()), ("test2".into(), "test2".into())], + authentication_method: Some("AuthenticationMethod".into()), + authentication_data: Some(b"AuthenticationData".to_vec()), + request_response_information: Some(0), + request_problem_information: Some(1), + }, + client_id: "ThisIsTheClientID".into(), + }; + + Packet::Connect(connect) +} + +pub fn publish_packet_1() -> Packet { Packet::Publish(Publish { dup: false, qos: QoS::ExactlyOnce, @@ -19,15 +156,15 @@ fn publish_packet_1() -> Packet { message_expiry_interval: None, topic_alias: None, response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], user_properties: vec![], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_2() -> Packet { +pub fn publish_packet_2() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::ExactlyOnce, @@ -39,15 +176,15 @@ fn publish_packet_2() -> Packet { message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], user_properties: vec![], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_3() -> Packet { +pub fn publish_packet_3() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::AtLeastOnce, @@ -59,15 +196,15 @@ fn publish_packet_3() -> Packet { message_expiry_interval: Some(3600), topic_alias: None, response_topic: Some("Please respond here thank you".into()), - correlation_data: Some(Bytes::from_static(b"5420874")), - subscription_identifier: vec![], + correlation_data: Some(b"5420874".to_vec()), + subscription_identifiers: vec![], user_properties: vec![("blabla".into(), "another blabla".into())], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_4() -> Packet { +pub fn publish_packet_4() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::AtLeastOnce, @@ -79,18 +216,18 @@ fn publish_packet_4() -> Packet { message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), // payload: Bytes::from_iter(b"abcdefg".repeat(500)), }) } pub fn create_subscribe_packet(packet_identifier: u16) -> Packet { - let subscription: Subscription = "test/topic".into(); + let subscription: SubscribeTopics = "test/topic".into(); let sub = Subscribe::new(packet_identifier, subscription.0); Packet::Subscribe(sub) } @@ -112,12 +249,33 @@ pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifie message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), }, - payload: Bytes::from_iter(b"testabcbba==asdasdasdasdasd".repeat(500)), + payload: b"testabcbba==asdasdasdasdasd".repeat(500).to_vec(), + }) +} + +pub fn create_empty_publish_packet() -> Packet { + Packet::Publish(Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test/#".into(), + packet_identifier: None, + publish_properties: PublishProperties { + payload_format_indicator: None, + message_expiry_interval: Some(3600), + topic_alias: Some(1), + response_topic: None, + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], + user_properties: vec![], + content_type: Some("Garbage".into()), + }, + payload: vec![], }) } @@ -143,6 +301,121 @@ pub fn create_disconnect_packet() -> Packet { }) } +pub fn suback_case() -> Packet { + let expected = SubAck { + packet_identifier: 3, + reason_codes: vec![SubAckReasonCode::GrantedQoS0, SubAckReasonCode::GrantedQoS1, SubAckReasonCode::GrantedQoS2], + properties: SubAckProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::SubAck(expected) +} + +pub fn subscribe_case() -> Packet { + let expected = Subscribe { + packet_identifier: 3, + topics: vec![("test/topic".into(), SubscriptionOptions::default())], + properties: SubscribeProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::Subscribe(expected) +} + +// return a crazy big packet +pub fn unsuback_case() -> Packet { + let expected = UnsubAck { + packet_identifier: 3, + reason_codes: vec![ + UnsubAckReasonCode::NoSubscriptionExisted, + UnsubAckReasonCode::UnspecifiedError, + UnsubAckReasonCode::ImplementationSpecificError, + ], + properties: UnsubAckProperties { + user_properties: vec![], + reason_string: None, + }, + }; + + Packet::UnsubAck(expected) +} + +pub fn unsubscribe_case() -> Packet { + let expected = Unsubscribe { + packet_identifier: 3, + topics: vec!["test/topic".into()], + properties: UnsubscribeProperties { + user_properties: vec![("written += 1;".into(), "value".into())], + }, + }; + + Packet::Unsubscribe(expected) +} + +pub fn pubrec_case() -> Packet { + let expected = PubRec { + packet_identifier: 3, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties { + reason_string: Some("test".into()), + user_properties: vec![("test5asdf".into(), "test3".into()), ("test4".into(), "test2".into())], + }, + }; + + Packet::PubRec(expected) +} + +pub fn pubcomp_case() -> Packet { + let expected = PubComp { + packet_identifier: 3, + reason_code: PubCompReasonCode::PacketIdentifierNotFound, + properties: PubCompProperties { + reason_string: Some("test".into()), + user_properties: vec![ + ("test5asdf".into(), "test3".into()), + ("test⌚5asdf".into(), "test3".into()), + ("test5asdf".into(), "test3".into()), + ("test5asdf".into(), "test3".into()), + ("test4".into(), "test2".into()), + ], + }, + }; + + Packet::PubComp(expected) +} + +pub fn pubrel_case2() -> Packet { + let expected = PubRel { + packet_identifier: 3, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: Some("test".into()), + user_properties: vec![("test5asdf".into(), "test3".repeat(10000).into()), ("test4".into(), "test2".into())], + }, + }; + + Packet::PubRel(expected) +} + +pub fn auth_case() -> Packet { + let expected = Auth { + reason_code: AuthReasonCode::ContinueAuthentication, + properties: AuthProperties { + authentication_method: Some("SomeRandomDataHere".into()), + authentication_data: Some(b"VeryRandomStuff".to_vec()), + reason_string: Some("⌚this_is_for_sure_a_test_⌚".into()), + user_properties: vec![("SureHopeThisWorks".into(), "😰".into())], + }, + }; + + Packet::Auth(expected) +} + #[rstest] #[case(create_subscribe_packet(1))] #[case(create_subscribe_packet(65335))] @@ -159,7 +432,7 @@ fn test_equal_write_read(#[case] packet: Packet) { packet.write(&mut buffer).unwrap(); - let read_packet = Packet::read_from_buffer(&mut buffer).unwrap(); + let read_packet = Packet::read(&mut buffer).unwrap(); assert_eq!(packet, read_packet); } diff --git a/mqrstt/src/tokio/mod.rs b/mqrstt/src/tokio/mod.rs index 234ec18..4aef142 100644 --- a/mqrstt/src/tokio/mod.rs +++ b/mqrstt/src/tokio/mod.rs @@ -2,91 +2,4 @@ mod stream; pub(crate) mod network; -use futures::Future; pub use network::Network; -pub use network::{NetworkReader, NetworkWriter}; - -use crate::error::ConnectionError; -use crate::packets::Packet; - -/// This empty struct is used to indicate the handling of messages goes via a mutable handler. -/// Only a single mutable reference can exist at once. -/// Thus this kind is not for concurrent message handling but for concurrent TCP read and write operations. -pub struct SequentialHandler; - -/// This empty struct is used to indicate a (tokio) task based handling of messages. -/// Per incoming message a task is spawned to call the handler. -/// -/// This kind of handler is used for both concurrent message handling and concurrent TCP read and write operations. -pub struct ConcurrentHandler; - -pub trait HandlerExt: Sized { - /// Should call the handler in the fashion of the handler. - /// (e.g. spawn a task if or await the handle call) - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; - - /// Should call the handler and await it - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; - - /// Should call the handler in the fashion of the handler. - /// (e.g. spawn a task if or await the handle call) - /// The reply (e.g. an ACK) to the original packet is only send when the handle call has completed - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send; -} - -impl HandlerExt for SequentialHandler { - #[inline] - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - #[inline] - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send, - { - async { - network.handler.handle(incoming_packet).await; - if let Some(reply_packet) = reply_packet { - network.to_writer_s.send(reply_packet).await?; - } - Ok(()) - } - } -} - -impl HandlerExt for ConcurrentHandler { - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - let handler_clone = handler.clone(); - tokio::spawn(async move { - handler_clone.handle(incoming_packet).await; - }); - std::future::ready(()) - } - #[inline] - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send, - { - let handler_clone = network.handler.clone(); - let write_channel_clone = network.to_writer_s.clone(); - - network.join_set.spawn(async move { - handler_clone.handle(incoming_packet).await; - if let Some(reply_packet) = reply_packet { - write_channel_clone.send(reply_packet).await?; - } - Ok(()) - }); - - std::future::ready(Ok(())) - } -} diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 8a2be8d..6b691cb 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -1,31 +1,27 @@ -use async_channel::{Receiver, Sender}; -use tokio::task::JoinSet; +use async_channel::Receiver; use std::marker::PhantomData; -use std::sync::atomic::AtomicBool; + use std::sync::Arc; use std::time::{Duration, Instant}; use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::ConnectionError; -use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; +use crate::packets::DisconnectReasonCode; use crate::packets::{Disconnect, Packet, PacketType}; -use crate::{AsyncEventHandlerMut, NetworkStatus, StateHandler}; +use crate::{AsyncEventHandler, NetworkStatus, StateHandler}; -use super::stream::Stream; -use super::{HandlerExt, SequentialHandler}; +use super::stream::StreamExt; /// [`Network`] reads and writes to the network based on tokios [`::tokio::io::AsyncReadExt`] [`::tokio::io::AsyncWriteExt`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct Network { - handler_helper: PhantomData, +pub struct Network { handler: PhantomData, - network: Option>, + network: Option, /// Options of the current mqtt connection options: ConnectOptions, @@ -35,10 +31,9 @@ pub struct Network { to_network_r: Receiver, } -impl Network { - pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { +impl Network { + pub(crate) fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { Self { - handler_helper: PhantomData, handler: PhantomData, network: None, @@ -54,15 +49,16 @@ impl Network { } } -/// Tokio impl -impl Network +impl Network where - N: HandlerExt, + H: AsyncEventHandler, S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { - let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; + /// + /// It is recommended to use a buffered stream. [`tokio::io::BufStream`] could be used to easily buffer both read and write. + pub async fn connect(&mut self, mut stream: S, handler: &mut H) -> Result<(), ConnectionError> { + let conn_ack = stream.connect(&self.options).await?; self.last_network_action = Instant::now(); if let Some(keep_alive_interval) = conn_ack.connack_properties.server_keep_alive { @@ -73,22 +69,21 @@ where } let packets = self.state_handler.handle_incoming_connack(&conn_ack)?; - N::call_handler_await(handler, Packet::ConnAck(conn_ack)).await; - if let Some(mut packets) = packets { - network.write_all(&mut packets).await?; + handler.handle(Packet::ConnAck(conn_ack)).await; + if let Some(packets) = packets { + stream.write_packets(&packets).await?; self.last_network_action = Instant::now(); } - self.network = Some(network); + self.network = Some(stream); Ok(()) } } -impl Network +impl Network where - H: AsyncEventHandlerMut, - SequentialHandler: HandlerExt, + H: AsyncEventHandler, S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// A single call to run will perform one of three tasks: @@ -103,13 +98,10 @@ where return Err(ConnectionError::NoNetwork); } - match self.tokio_select(handler).await { - otherwise => { - self.network = None; + let result = self.tokio_select(handler).await; + self.network = None; - otherwise - } - } + result } async fn tokio_select(&mut self, handler: &mut H) -> Result { @@ -119,13 +111,11 @@ where last_network_action, perform_keep_alive, to_network_r, - handler_helper: _, handler: _, state_handler, } = self; let mut await_pingresp = None; - let mut outgoing_packet_buffer = Vec::new(); loop { let sleep; @@ -137,45 +127,48 @@ where if let Some(stream) = network { tokio::select! { - res = stream.read_bytes() => { - res?; - loop{ - let packet = match stream.parse_message().await { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => break, - Ok(packet) => packet, - }; - match packet{ - Packet::PingResp => { - SequentialHandler::call_handler_await(handler, packet).await; - await_pingresp = None; - }, - Packet::Disconnect(_) => { - SequentialHandler::call_handler_await(handler, packet).await; - return Ok(NetworkStatus::IncomingDisconnect); - } - packet => { - match state_handler.handle_incoming_packet(&packet)? { - (maybe_reply_packet, true) => { - SequentialHandler::call_handler_await(handler, packet).await; - if let Some(reply_packet) = maybe_reply_packet { - outgoing_packet_buffer.push(reply_packet); - } - }, - (Some(reply_packet), false) => { - outgoing_packet_buffer.push(reply_packet); - }, - (None, false) => (), - } + res = stream.read_packet() => { + #[cfg(feature = "logs")] + tracing::trace!("Received incoming packet {:?}", &res); + + let packet = res?; + match packet{ + Packet::PingResp => { + handler.handle(packet).await; + await_pingresp = None; + }, + Packet::Disconnect(_) => { + handler.handle(packet).await; + return Ok(NetworkStatus::IncomingDisconnect); + } + packet => { + match state_handler.handle_incoming_packet(&packet)? { + (maybe_reply_packet, true) => { + handler.handle(packet).await; + if let Some(reply_packet) = maybe_reply_packet { + stream.write_packet(&reply_packet).await?; + *last_network_action = Instant::now(); + } + }, + (Some(reply_packet), false) => { + stream.write_packet(&reply_packet).await?; + *last_network_action = Instant::now(); + }, + (None, false) => (), } } - stream.write_all(&mut outgoing_packet_buffer).await?; - *last_network_action = Instant::now(); } }, outgoing = to_network_r.recv() => { + #[cfg(feature = "logs")] + tracing::trace!("Received outgoing item {:?}", &outgoing); + let packet = outgoing?; - stream.write(&packet).await?; + + #[cfg(feature = "logs")] + tracing::trace!("Sending packet {}", packet); + + stream.write_packet(&packet).await?; let disconnect = packet.packet_type() == PacketType::Disconnect; state_handler.handle_outgoing_packet(packet)?; @@ -188,13 +181,13 @@ where }, _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { let packet = Packet::PingReq; - stream.write(&packet).await?; + stream.write_packet(&packet).await?; *last_network_action = Instant::now(); await_pingresp = Some(Instant::now()); }, _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - stream.write(&Packet::Disconnect(disconnect)).await?; + stream.write_packet(&Packet::Disconnect(disconnect)).await?; return Ok(NetworkStatus::KeepAliveTimeout); } } @@ -203,228 +196,86 @@ where } } } -} - -impl Network -where - S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, -{ - /// Creates both read and write tasks to run this them in parallel. - /// If you want to run concurrently (not parallel) the [`Self::run`] method is a better aproach! - pub fn split(&mut self, handler: H) -> Result<(NetworkReader, NetworkWriter), ConnectionError> { - if self.network.is_none() { - return Err(ConnectionError::NoNetwork)?; - } - - match self.network.take() { - Some(network) => { - let (read_stream, write_stream) = network.split(); - let run_signal = Arc::new(AtomicBool::new(true)); - let (to_writer_s, to_writer_r) = async_channel::bounded(100); - let await_pingresp_atomic = Arc::new(AtomicBool::new(false)); - - let read_network = NetworkReader { - run_signal: run_signal.clone(), - handler_helper: PhantomData, - handler: handler, - read_stream, - await_pingresp_atomic: await_pingresp_atomic.clone(), - state_handler: self.state_handler.clone(), - to_writer_s, - join_set: JoinSet::new(), - }; - - let write_network = NetworkWriter { - run_signal: run_signal.clone(), - write_stream, - keep_alive_interval: self.options.keep_alive_interval, - last_network_action: self.last_network_action, - await_pingresp_bool: await_pingresp_atomic.clone(), - await_pingresp_time: None, - perform_keep_alive: self.perform_keep_alive, - state_handler: self.state_handler.clone(), - to_writer_r: to_writer_r, - to_network_r: self.to_network_r.clone(), - }; - - Ok((read_network, write_network)) - } - None => Err(ConnectionError::NoNetwork), - } - } -} - -pub struct NetworkReader { - pub(crate) run_signal: Arc, - - pub(crate) handler_helper: PhantomData, - pub handler: H, - - pub(crate) read_stream: super::stream::read_half::ReadStream, - pub(crate) await_pingresp_atomic: Arc, - pub(crate) state_handler: Arc, - pub(crate) to_writer_s: Sender, - pub(crate) join_set: JoinSet>, -} - -impl NetworkReader -where - N: HandlerExt, - S: tokio::io::AsyncReadExt + Sized + Unpin + Send + 'static, -{ - /// Runs the read half of the mqtt connection. - /// Continuously loops until disconnect or error. - /// - /// # Return - /// - Ok(None) in the case that the write task requested shutdown. - /// - Ok(Some(reason)) in the case that this task initiates a shutdown. - /// - Err in the case of IO, or protocol errors. - pub async fn run(mut self) -> (Result, H) { - let ret = self.read().await; - self.run_signal.store(false, std::sync::atomic::Ordering::Release); - while let Some(_) = self.join_set.join_next().await { - () - } - (ret, self.handler) - } - async fn read(&mut self) -> Result { - while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { - let _ = self.read_stream.read_bytes().await?; - loop { - let packet = match self.read_stream.parse_message() { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => { - break; - } - Ok(packet) => packet, - }; - - match packet { - Packet::PingResp => { - N::call_handler(&mut self.handler, packet).await; - #[cfg(feature = "logs")] - if !self.await_pingresp_atomic.fetch_and(false, std::sync::atomic::Ordering::SeqCst) { - tracing::warn!("Received PingResp but did not expect it"); - } - #[cfg(not(feature = "logs"))] - self.await_pingresp_atomic.store(false, std::sync::atomic::Ordering::SeqCst); - } - Packet::Disconnect(_) => { - N::call_handler(&mut self.handler, packet).await; - return Ok(NetworkStatus::IncomingDisconnect); - } - Packet::ConnAck(conn_ack) => { - if let Some(retransmit_packets) = self.state_handler.handle_incoming_connack(&conn_ack)? { - for packet in retransmit_packets.into_iter() { - self.to_writer_s.send(packet).await?; - } - } - N::call_handler(&mut self.handler, Packet::ConnAck(conn_ack)).await; - } - packet => match self.state_handler.handle_incoming_packet(&packet)? { - (maybe_reply_packet, true) => { - N::call_handler_with_reply(self, packet, maybe_reply_packet).await?; - } - (Some(reply_packet), false) => { - self.to_writer_s.send(reply_packet).await?; - } - (None, false) => (), - }, - } - } - } - Ok(NetworkStatus::ShutdownSignal) - } -} - -pub struct NetworkWriter { - run_signal: Arc, - - write_stream: super::stream::write_half::WriteStream, - - keep_alive_interval: Duration, - - last_network_action: Instant, - await_pingresp_bool: Arc, - await_pingresp_time: Option, - perform_keep_alive: bool, - - state_handler: Arc, - - to_writer_r: Receiver, - to_network_r: Receiver, -} -impl NetworkWriter -where - S: tokio::io::AsyncWriteExt + Sized + Unpin, -{ - /// Runs the read half of the mqtt connection. - /// Continuously loops until disconnect or error. - /// - /// # Return - /// - Ok(None) in the case that the read task requested shutdown - /// - Ok(Some(reason)) in the case that this task initiates a shutdown - /// - Err in the case of IO, or protocol errors. - pub async fn run(mut self) -> Result { - let ret = self.write().await; - self.run_signal.store(false, std::sync::atomic::Ordering::Release); - ret - } - async fn write(&mut self) -> Result { - while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { - if self.await_pingresp_time.is_some() && !self.await_pingresp_bool.load(std::sync::atomic::Ordering::Acquire) { - self.await_pingresp_time = None; - } - - let sleep; - if let Some(instant) = &self.await_pingresp_time { - sleep = *instant + self.keep_alive_interval - Instant::now(); - } else { - sleep = self.last_network_action + self.keep_alive_interval - Instant::now(); - }; - tokio::select! { - outgoing = self.to_network_r.recv() => { - let packet = outgoing?; - self.write_stream.write(&packet).await?; - - let disconnect = packet.packet_type() == PacketType::Disconnect; - - self.state_handler.handle_outgoing_packet(packet)?; - self.last_network_action = Instant::now(); - - if disconnect{ - return Ok(NetworkStatus::OutgoingDisconnect); - } - }, - from_reader = self.to_writer_r.recv() => { - let packet = from_reader?; - self.write_stream.write(&packet).await?; - match packet { - foo @ (Packet::Publish(_) | Packet::Subscribe(_) | Packet::Unsubscribe(_) | Packet::Disconnect(_)) => { - self.state_handler.handle_outgoing_packet(foo)?; - }, - _ => (), - } - self.last_network_action = Instant::now(); - }, - _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_none() && self.perform_keep_alive => { - let packet = Packet::PingReq; - self.write_stream.write(&packet).await?; - self.await_pingresp_bool.store(true, std::sync::atomic::Ordering::SeqCst); - self.last_network_action = Instant::now(); - self.await_pingresp_time = Some(Instant::now()); - }, - _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_some() => { - self.await_pingresp_time = None; - if self.await_pingresp_bool.load(std::sync::atomic::Ordering::SeqCst){ - let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - self.write_stream.write(&Packet::Disconnect(disconnect)).await?; - return Ok(NetworkStatus::KeepAliveTimeout); - } - } - } - } - Ok(NetworkStatus::ShutdownSignal) - } + // async fn concurrent_tokio_select(&mut self, handler: &mut H) -> Result { + // let Network { + // network, + // options, + // last_network_action, + // perform_keep_alive, + // to_network_r, + // handler: _, + // state_handler, + // } = self; + + // let mut await_pingresp = None; + + // loop { + // let sleep; + // if let Some(instant) = await_pingresp { + // sleep = instant + options.get_keep_alive_interval() - Instant::now(); + // } else { + // sleep = *last_network_action + options.get_keep_alive_interval() - Instant::now(); + // } + + // if let Some(stream) = network { + // tokio::select! { + // res = stream.read_packet() => { + // let packet = res?; + // match packet{ + // Packet::PingResp => { + // handler.handle(packet).await; + // await_pingresp = None; + // }, + // Packet::Disconnect(_) => { + // handler.handle(packet).await; + // return Ok(NetworkStatus::IncomingDisconnect); + // } + // packet => { + // match state_handler.handle_incoming_packet(&packet)? { + // (maybe_reply_packet, true) => { + // handler.handle(packet).await; + // if let Some(reply_packet) = maybe_reply_packet { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // } + // }, + // (Some(reply_packet), false) => { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // }, + // (None, false) => (), + // } + // } + // } + // }, + // outgoing = to_network_r.recv() => { + // let packet = outgoing?; + // stream.write_packet(&packet).await?; + // let disconnect = packet.packet_type() == PacketType::Disconnect; + + // state_handler.handle_outgoing_packet(packet)?; + // *last_network_action = Instant::now(); + + // if disconnect{ + // return Ok(NetworkStatus::OutgoingDisconnect); + // } + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + // let packet = Packet::PingReq; + // stream.write_packet(&packet).await?; + // *last_network_action = Instant::now(); + // await_pingresp = Some(Instant::now()); + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + // let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + // stream.write_packet(&Packet::Disconnect(disconnect)).await?; + // return Ok(NetworkStatus::KeepAliveTimeout); + // } + // } + // } else { + // return Err(ConnectionError::NoNetwork); + // } + // } + // } } diff --git a/mqrstt/src/tokio/stream.rs b/mqrstt/src/tokio/stream.rs new file mode 100644 index 0000000..9a720ad --- /dev/null +++ b/mqrstt/src/tokio/stream.rs @@ -0,0 +1,78 @@ +use tokio::io::AsyncWriteExt; + +#[cfg(feature = "logs")] +use tracing::trace; + +use crate::packets::ConnAck; +use crate::packets::{ConnAckReasonCode, Packet}; +use crate::{connect_options::ConnectOptions, error::ConnectionError}; + +pub(crate) trait StreamExt { + fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future>; + fn read_packet(&mut self) -> impl std::future::Future>; + fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future>; + fn write_packets(&mut self, packets: &[Packet]) -> impl std::future::Future>; + fn flush_packets(&mut self) -> impl std::future::Future>; +} + +impl StreamExt for S +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, +{ + async fn connect(&mut self, options: &ConnectOptions) -> Result { + let connect = options.create_connect_from_options(); + + self.write_packet(&connect).await?; + + let packet = Packet::async_read(self).await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); + Ok(con) + } else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } + } else { + Err(ConnectionError::NotConnAck(packet)) + } + } + + async fn read_packet(&mut self) -> Result { + Ok(Packet::async_read(self).await?) + } + + async fn write_packet(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + match packet.async_write(self).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + } + } + } + + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); + + self.flush().await?; + // self.flush_packets().await?; + + Ok(()) + } + + async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> { + for packet in packets { + let _ = packet.async_write(self).await; + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); + } + self.flush_packets().await?; + Ok(()) + } + + fn flush_packets(&mut self) -> impl std::future::Future> { + tokio::io::AsyncWriteExt::flush(self) + } +} diff --git a/mqrstt/src/tokio/stream/mod.rs b/mqrstt/src/tokio/stream/mod.rs deleted file mode 100644 index ad928d5..0000000 --- a/mqrstt/src/tokio/stream/mod.rs +++ /dev/null @@ -1,176 +0,0 @@ -pub mod read_half; -pub mod write_half; - -use std::io::{self, Error, ErrorKind}; - -use bytes::{Buf, BytesMut}; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -#[cfg(feature = "logs")] -use tracing::trace; - -use crate::packets::ConnAck; -use crate::packets::{ - error::ReadBytes, - reason_codes::ConnAckReasonCode, - {FixedHeader, Packet}, -}; -use crate::{connect_options::ConnectOptions, error::ConnectionError}; - -use self::read_half::ReadStream; -use self::write_half::WriteStream; - -#[derive(Debug)] -pub struct Stream { - pub stream: S, - - /// Input buffer - const_buffer: [u8; 4096], - - /// Write buffer - read_buffer: BytesMut, - - /// Write buffer - write_buffer: BytesMut, -} - -impl Stream { - pub async fn parse_message(&mut self) -> Result> { - let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; - - if header.remaining_length + header_length > self.read_buffer.len() { - return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - - #[cfg(feature = "logs")] - trace!("Read packet from network {}", read_packet); - - Ok(read_packet) - } -} - -impl Stream -where - S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, -{ - pub fn split(self) -> (ReadStream, WriteStream) { - let Self { - stream, - const_buffer, - read_buffer, - write_buffer, - } = self; - - let (read_stream, write_stream) = tokio::io::split(stream); - - (ReadStream::new(read_stream, const_buffer, read_buffer), WriteStream::new(write_stream, write_buffer)) - } - - pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, ConnAck), ConnectionError> { - let mut s = Self { - stream, - const_buffer: [0; 4096], - read_buffer: BytesMut::new(), - write_buffer: BytesMut::new(), - }; - - let connect = options.create_connect_from_options(); - - s.write(&connect).await?; - - let packet = s.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((s, con)) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.read_buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_required_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - if header_length + header.remaining_length > self.read_buffer.len() { - self.read_required_bytes(header.remaining_length - self.read_buffer.len()).await?; - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if read == 0 { - Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) - } else { - self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.read_bytes().await?; - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } - - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.write_buffer)?; - - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - - self.stream.write_all(&self.write_buffer[..]).await?; - self.stream.flush().await?; - self.write_buffer.clear(); - Ok(()) - } - - pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { - let writes = packets.drain(0..).map(|packet| { - packet.write(&mut self.write_buffer)?; - - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - - Ok::<(), ConnectionError>(()) - }); - - for write in writes { - write?; - } - - self.stream.write_all(&self.write_buffer[..]).await?; - self.stream.flush().await?; - self.write_buffer.clear(); - Ok(()) - } -} diff --git a/mqrstt/src/tokio/stream/read_half.rs b/mqrstt/src/tokio/stream/read_half.rs deleted file mode 100644 index 0d5be0b..0000000 --- a/mqrstt/src/tokio/stream/read_half.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::io; - -use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncReadExt, ReadHalf}; - -use crate::{ - error::ConnectionError, - packets::{error::ReadBytes, FixedHeader, Packet}, -}; - -#[cfg(feature = "logs")] -use tracing::trace; - -#[derive(Debug)] -pub struct ReadStream { - stream: ReadHalf, - - /// Input buffer - const_buffer: [u8; 4096], - - /// Write buffer - read_buffer: BytesMut, -} - -impl ReadStream -where - S: tokio::io::AsyncRead + Sized + Unpin, -{ - pub fn new(stream: ReadHalf, const_buffer: [u8; 4096], read_buffer: BytesMut) -> Self { - Self { stream, const_buffer, read_buffer } - } - - pub fn parse_message(&mut self) -> Result> { - let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; - - if header.remaining_length + header_length > self.read_buffer.len() { - return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - - #[cfg(feature = "logs")] - trace!("Read packet from network {}", read_packet); - - Ok(read_packet) - } - - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if read == 0 { - Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) - } else { - self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } - } -} diff --git a/mqrstt/src/tokio/stream/write_half.rs b/mqrstt/src/tokio/stream/write_half.rs deleted file mode 100644 index 9bc5fb4..0000000 --- a/mqrstt/src/tokio/stream/write_half.rs +++ /dev/null @@ -1,38 +0,0 @@ -use bytes::BytesMut; -use tokio::io::{AsyncWriteExt, WriteHalf}; - -use crate::{error::ConnectionError, packets::Packet}; - -#[cfg(feature = "logs")] -use tracing::trace; - -#[derive(Debug)] -pub struct WriteStream { - pub stream: WriteHalf, - - /// Write buffer - write_buffer: BytesMut, -} - -impl WriteStream { - pub fn new(stream: WriteHalf, write_buffer: BytesMut) -> Self { - Self { stream, write_buffer } - } -} - -impl WriteStream -where - S: tokio::io::AsyncWrite + Sized + Unpin, -{ - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.write_buffer)?; - - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - - self.stream.write_all(&self.write_buffer[..]).await?; - self.stream.flush().await?; - self.write_buffer.clear(); - Ok(()) - } -} diff --git a/mqrstt/src/util/constants.rs b/mqrstt/src/util/constants.rs index 023c60d..5ea5612 100644 --- a/mqrstt/src/util/constants.rs +++ b/mqrstt/src/util/constants.rs @@ -1,3 +1,3 @@ -pub const DEFAULT_RECEIVE_MAXIMUM: u16 = 65535; -pub const MAXIMUM_PACKET_SIZE: u32 = 268435455; -pub const MAXIMUM_TOPIC_SIZE: usize = 65535; +pub(crate) const DEFAULT_RECEIVE_MAXIMUM: u16 = 65535; +pub(crate) const MAXIMUM_PACKET_SIZE: u32 = 268435455; +pub(crate) const MAXIMUM_TOPIC_SIZE: usize = 65535; diff --git a/rust-toolchain b/rust-toolchain deleted file mode 100644 index 982f51e..0000000 --- a/rust-toolchain +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "1.75.0" \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..d519a31 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +#channel = "1.82.0" \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml index c9ce889..f494da5 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,4 @@ -# unstable_features = true -# brace_style = "PreferSameLine" -# control_brace_style = "ClosingNextLine" +# unstable_features = true +# brace_style = "PreferSameLine" +# control_brace_style = "ClosingNextLine" max_width = 200 \ No newline at end of file