From 7cd0424e8a3302d5ef90e7d4c9eb809056bc0ff8 Mon Sep 17 00:00:00 2001 From: Gunnar <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 10 Jan 2023 11:36:02 +0100 Subject: [PATCH] Changed connections to allow user to provide the stream (#4) * Added keep alive to tokio network * Removed nested packets mod * Remove nightly toolchain and async_fn_in_trait feature * Removed more code dependent on async_fn_in_trait * Adjusted smol stream to return pingresp and disconnect indicators too. * Removed atomic_disconnect and added return info Removed atomic disconnect from event handler and added enum to return status of the event handler. * Removed commented section * Added return info and keep alive * Added async_trait for handler * Removed stages, they don't work atm * Adjusted license to MPL * Removed warn from tokio_network * Fixed tokio example test in lib * Always perform reset in the needed cases * Adjust lib tests * fmt and clippy * Adjusted trait bounds * Adjusted doc example * Fix lib test cases * rustfmt * Remove vscode dir * Add license file * Added codecov badge and adjusted version * Fix license cargo deny --- .vscode/settings.json | 9 - Cargo.lock | 85 +-- Cargo.toml | 39 +- LICENSE | 373 +++++++++++ README.md | 27 +- deny.toml | 17 +- rust-toolchain.toml | 2 - rustfmt.toml | 3 + src/available_packet_ids.rs | 32 +- src/client.rs | 11 +- src/connect_options.rs | 93 +-- src/connections/async_native_tls.rs | 236 ------- src/connections/async_rustls.rs | 297 --------- src/connections/mod.rs | 81 +-- src/connections/smol_stream.rs | 176 +++++ src/connections/smol_tcp.rs | 220 ------- src/connections/tokio_rustls.rs | 290 --------- src/connections/tokio_stream.rs | 174 +++++ src/connections/tokio_tcp.rs | 223 ------- src/connections/transport.rs | 31 - src/connections/util.rs | 74 --- src/error.rs | 43 +- src/event_handler.rs | 975 ++++++++++++++-------------- src/lib.rs | 580 ++++++++++++----- src/network.rs | 121 ---- src/packets/auth.rs | 3 +- src/packets/connack.rs | 3 +- src/packets/connect.rs | 12 +- src/packets/disconnect.rs | 9 +- src/packets/mod.rs | 546 +++++++++++++++- src/packets/packets.rs | 531 --------------- src/packets/puback.rs | 24 +- src/packets/pubcomp.rs | 30 +- src/packets/publish.rs | 6 +- src/packets/pubrec.rs | 18 +- src/packets/pubrel.rs | 34 +- src/packets/suback.rs | 6 +- src/packets/subscribe.rs | 8 +- src/packets/unsuback.rs | 6 +- src/packets/unsubscribe.rs | 3 +- src/smol_network.rs | 166 +++++ src/state.rs | 1 - src/tests/handler_tests.rs | 1 - src/tests/mod.rs | 2 - src/tests/resources/test_packets.rs | 11 +- src/tests/stages.rs | 69 -- src/tokio_network.rs | 156 +++++ src/util/mod.rs | 1 + src/util/timeout.rs | 48 +- src/util/tls.rs | 75 +++ 50 files changed, 2803 insertions(+), 3178 deletions(-) delete mode 100644 .vscode/settings.json create mode 100644 LICENSE delete mode 100644 rust-toolchain.toml create mode 100644 rustfmt.toml delete mode 100644 src/connections/async_native_tls.rs delete mode 100644 src/connections/async_rustls.rs create mode 100644 src/connections/smol_stream.rs delete mode 100644 src/connections/smol_tcp.rs delete mode 100644 src/connections/tokio_rustls.rs create mode 100644 src/connections/tokio_stream.rs delete mode 100644 src/connections/tokio_tcp.rs delete mode 100644 src/connections/transport.rs delete mode 100644 src/connections/util.rs delete mode 100644 src/network.rs delete mode 100644 src/packets/packets.rs create mode 100644 src/smol_network.rs delete mode 100644 src/tests/handler_tests.rs delete mode 100644 src/tests/stages.rs create mode 100644 src/tokio_network.rs create mode 100644 src/util/tls.rs diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index cf4cb03..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "cSpell.words": [ - "acked", - "QUIC", - "runtimes", - "smol", - "unacked" - ] -} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 0689cb2..a76ad69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -125,6 +125,17 @@ version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524" +[[package]] +name = "async-trait" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "705339e0e4a9690e2908d2b3d049d85682cf19fbd5782494498fbf7003a6a282" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.0.0" @@ -149,18 +160,6 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" -[[package]] -name = "bitvec" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - [[package]] name = "blocking" version = "1.3.0" @@ -248,12 +247,6 @@ dependencies = [ "instant", ] -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - [[package]] name = "futures" version = "0.3.25" @@ -278,17 +271,6 @@ dependencies = [ "futures-sink", ] -[[package]] -name = "futures-concurrency" -version = "7.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a740c32e1bde284ce2f51df98abd4fa38e9e539670443c111211777e3ab09927" -dependencies = [ - "bitvec", - "futures-core", - "pin-project", -] - [[package]] name = "futures-core" version = "0.3.25" @@ -434,15 +416,15 @@ dependencies = [ [[package]] name = "mqrstt" -version = "0.1.0" +version = "0.1.1" dependencies = [ "async-channel", "async-mutex", "async-rustls", + "async-trait", "bitflags", "bytes", "futures", - "futures-concurrency", "pretty_assertions", "rustls", "rustls-pemfile", @@ -502,26 +484,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72" -[[package]] -name = "pin-project" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.9" @@ -578,12 +540,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "radium" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" - [[package]] name = "regex" version = "1.7.0" @@ -741,12 +697,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - [[package]] name = "thiserror" version = "1.0.37" @@ -1070,15 +1020,6 @@ version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" -[[package]] -name = "wyz" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" -dependencies = [ - "tap", -] - [[package]] name = "yansi" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 8b1ba1c..7f409aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,23 +1,23 @@ [package] name = "mqrstt" -version = "0.1.0" +version = "0.1.1" +homepage = "https://github.com/GunnarMorrigan/mqrstt" +repository = "https://github.com/GunnarMorrigan/mqrstt" +categories = ["network-programming"] +readme = "README.md" edition = "2021" -license = "Apache-2.0" +license = "MPL-2.0" +keywords = [ "MQTT", "IoT", "MQTTv5", "messaging", "client" ] +description = "Pure rust MQTTv5 client implementation for Smol, Tokio and soon sync too." # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["smol", "tcp"] +default = ["smol", "tokio"] tokio = ["dep:tokio"] smol = ["dep:smol"] -tcp = [] -tokio-rustls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pemfile", "dep:webpki"] -smol-rustls = ["dep:async-rustls", "dep:rustls", "dep:rustls-pemfile", "dep:webpki"] - -# If in the future we only provide an MQTT packet handler to make it async, sync and runtime agnostic -# then this is not needed # quic = ["dep:quinn"] -# native-tls = ["dep:async-native-tls"] + [dependencies] # Packets @@ -27,32 +27,29 @@ bitflags = "1.3.2" # Errors thiserror = "1.0.37" tracing = "0.1.37" -tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } async-channel = "1.8.0" async-mutex = "1.4.0" -futures-concurrency = "7.0.0" futures = { version = "0.3.25", default-features = false, features = ["std", "async-await"] } - -# Needed for all TLS -rustls = { version = "0.20.7", optional = true } -rustls-pemfile = { version = "1.0.1", optional = true } -webpki = { version = "0.22.0", optional = true } +async-trait = "0.1.61" # quic feature flag # quinn = {version = "0.9.0", optional = true } # tokio feature flag tokio = { version = "1.21", features = ["macros", "io-util", "net", "time"], optional = true } -tokio-rustls = { version = "0.23.4", optional = true } # smol feature flag smol = { version = "1.3.0", optional = true } -async-rustls = { version = "0.3.0", optional = true } -#async-native-tls = { version = "0.4.0", optional = true } [dev-dependencies] pretty_assertions = "1.3.0" tokio = { version = "1.21", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } smol = { version = "1.3.0" } -tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } \ No newline at end of file +tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } + +rustls = { version = "0.20.7" } +rustls-pemfile = { version = "1.0.1" } +webpki = { version = "0.22.0" } +async-rustls = { version = "0.3.0" } +tokio-rustls = "0.23.4" \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ee6256c --- /dev/null +++ b/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at https://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/README.md b/README.md index 89f320c..2b3f334 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,25 @@ -# mqrstt -Async MQTT client over TCP and QUIC +
+ +# `📟 mqrstt` + +[![Crates.io](https://img.shields.io/crates/v/mqrstt.svg)](https://crates.io/crates/mqrstt) +[![Docs](https://docs.rs/mqrstt/badge.svg)](https://docs.rs/mqrstt) +[![dependency status](https://deps.rs/repo/github/GunnarMorrigan/mqrstt/status.svg)](https://deps.rs/repo/github/GunnarMorrigan/mqrstt) +[![codecov](https://codecov.io/github/GunnarMorrigan/mqrstt/branch/main/graph/badge.svg?token=YSZFYQ063Y)](https://codecov.io/github/GunnarMorrigan/mqrstt) + +`mqrstt` is an MQTTv5 client implementation that follows the [sans-io](https://sans-io.readthedocs.io/) approach. + +
+ + +## License + +Licensed under + +* Mozilla Public License, Version 2.0, ([MPL-2.0](https://choosealicense.com/licenses/mpl-2.0/) + +## Contribution + +Unless you explicitly state otherwise, any contribution intentionally +submitted for inclusion in the work by you, shall be licensed under MPL-2.0, without any additional terms or +conditions. \ No newline at end of file diff --git a/deny.toml b/deny.toml index 33fd274..b24eaba 100644 --- a/deny.toml +++ b/deny.toml @@ -8,8 +8,23 @@ unlicensed = "deny" allow-osi-fsf-free = "neither" copyleft = "deny" confidence-threshold = 0.95 -allow = ["Apache-2.0", "MIT", "BSD-3-Clause", "ISC"] +allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC"] exceptions = [ { allow = ["Unicode-DFS-2016"], name = "unicode-ident" }, + { allow = ["OpenSSL"], name = "ring" } +] + +[[licenses.clarify]] +name = "ring" +expression = "MIT AND ISC AND OpenSSL" +license-files = [ + { path = "LICENSE", hash = 0xbd0eed23 } +] + +[[licenses.clarify]] +name = "webpki" +expression = "ISC" +license-files = [ + { path = "LICENSE", hash = 0x001c7e6c }, ] \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml deleted file mode 100644 index 271800c..0000000 --- a/rust-toolchain.toml +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "nightly" \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..e672bfc --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,3 @@ +unstable_features = true +brace_style = "PreferSameLine" +control_brace_style = "ClosingNextLine" \ No newline at end of file diff --git a/src/available_packet_ids.rs b/src/available_packet_ids.rs index 9502573..a6c2d80 100644 --- a/src/available_packet_ids.rs +++ b/src/available_packet_ids.rs @@ -1,5 +1,5 @@ use async_channel::{Receiver, Sender}; -use tracing::{debug, error}; +use tracing::error; use crate::error::MqttError; @@ -20,21 +20,21 @@ impl AvailablePacketIds { (apkid, r) } - pub fn try_mark_available(&self, pkid: u16) -> Result<(), MqttError> { - match self.sender.try_send(pkid) { - Ok(_) => { - Ok(()) - // debug!("Marked packet id as available: {}", pkid); - } - Err(err) => { - error!( - "Encountered an error while marking an packet id as available. Error: {}", - err - ); - Err(MqttError::PacketIdError(err.into_inner())) - } - } - } + // pub fn try_mark_available(&self, pkid: u16) -> Result<(), MqttError> { + // match self.sender.try_send(pkid) { + // Ok(_) => { + // Ok(()) + // // debug!("Marked packet id as available: {}", pkid); + // } + // Err(err) => { + // error!( + // "Encountered an error while marking an packet id as available. Error: {}", + // err + // ); + // Err(MqttError::PacketIdError(err.into_inner())) + // } + // } + // } pub async fn mark_available(&self, pkid: u16) -> Result<(), MqttError> { match self.sender.send(pkid).await { diff --git a/src/client.rs b/src/client.rs index c0848ff..c4bd4c2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,13 +5,10 @@ use tracing::info; use crate::{ error::ClientError, packets::{ - {Disconnect, DisconnectProperties}, - Packet, - {Publish, PublishProperties}, reason_codes::DisconnectReasonCode, + Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, Subscription}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, - QoS, }, }; @@ -115,7 +112,8 @@ impl AsyncClient { "Published message into network_packet_sender. len {}", self.to_network_s.len() ); - } else { + } + else { self.client_to_handler_s .send(Packet::Publish(publish)) .await @@ -159,7 +157,8 @@ impl AsyncClient { .send(Packet::Publish(publish)) .await .map_err(|_| ClientError::NoHandler)?; - } else { + } + else { self.client_to_handler_s .send(Packet::Publish(publish)) .await diff --git a/src/connect_options.rs b/src/connect_options.rs index 718a65d..85c560c 100644 --- a/src/connect_options.rs +++ b/src/connect_options.rs @@ -7,16 +7,6 @@ use crate::util::constants::RECEIVE_MAXIMUM_DEFAULT; #[derive(Debug, Clone)] pub struct ConnectOptions { - /// broker address that you want to connect to - pub(crate) address: String, - /// broker port - pub(crate) port: u16, - - #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] - pub(crate) tls_config: Option, - - // /// What transport protocol to use - // transport: Transport, /// keep alive time to send pingreq to broker when the connection is idle pub keep_alive_interval_s: u64, pub connection_timeout_s: u64, @@ -28,38 +18,33 @@ pub struct ConnectOptions { pub username: Option, pub password: Option, /// request (publish, subscribe) channel capacity - channel_capacity: usize, + pub channel_capacity: usize, /// Minimum delay time between consecutive outgoing packets /// while retransmitting pending packets // TODO! IMPLEMENT THIS! - pending_throttle_s: u64, + pub pending_throttle_s: u64, - send_reason_messages: bool, + pub send_reason_messages: bool, // MQTT v5 Connect Properties: - session_expiry_interval: Option, - pub(crate) receive_maximum: Option, - maximum_packet_size: Option, - topic_alias_maximum: Option, - request_response_information: Option, - request_problem_information: Option, - user_properties: Vec<(String, String)>, - authentication_method: Option, - authentication_data: Bytes, + pub session_expiry_interval: Option, + pub receive_maximum: Option, + pub maximum_packet_size: Option, + pub topic_alias_maximum: Option, + pub request_response_information: Option, + pub request_problem_information: Option, + pub user_properties: Vec<(String, String)>, + pub authentication_method: Option, + pub authentication_data: Bytes, /// Last will that will be issued on unexpected disconnect - last_will: Option, + pub last_will: Option, } impl ConnectOptions { - pub fn new(address: String, port: u16, client_id: String) -> Self { + pub fn new(client_id: String) -> Self { Self { - address, - port, - #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] - tls_config: None, - keep_alive_interval_s: 60, connection_timeout_s: 30, clean_session: false, @@ -83,46 +68,6 @@ impl ConnectOptions { } } - #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] - pub(crate) fn new_with_tls_config( - address: String, - port: u16, - client_id: String, - tls_config: Option, - ) -> Self { - Self { - address, - port, - tls_config, - - keep_alive_interval_s: 60, - connection_timeout_s: 30, - clean_session: false, - client_id, - credentials: None, - channel_capacity: 100, - pending_throttle_s: 30, - send_reason_messages: false, - - session_expiry_interval: None, - receive_maximum: None, - maximum_packet_size: None, - topic_alias_maximum: None, - request_response_information: None, - request_problem_information: None, - user_properties: vec![], - authentication_method: None, - authentication_data: Bytes::new(), - last_will: None, - } - } - - pub fn set_address(&mut self, address: String) { - self.address = address - } - pub fn set_port(&mut self, port: u16) { - self.port = port - } pub fn set_keep_alive_interval_s(&mut self, keep_alive_interval_s: u64) { self.keep_alive_interval_s = keep_alive_interval_s } @@ -141,17 +86,12 @@ impl ConnectOptions { pub fn set_pending_throttle_s(&mut self, pending_throttle_s: u64) { self.pending_throttle_s = pending_throttle_s } - pub fn set_send_reason_messages(&mut self, send_reason_messages: bool) { - self.send_reason_messages = send_reason_messages - } - pub fn set_session_expiry_interval(&mut self, session_expiry_interval: u32) { self.session_expiry_interval = Some(session_expiry_interval) } pub fn clear_session_expiry_interval(&mut self) { self.session_expiry_interval = None } - pub fn set_receive_maximum(&mut self, receive_maximum: u16) { self.receive_maximum = Some(receive_maximum) } @@ -161,7 +101,6 @@ impl ConnectOptions { pub fn receive_maximum(&self) -> u16 { self.receive_maximum.unwrap_or(RECEIVE_MAXIMUM_DEFAULT) } - pub fn set_maximum_packet_size(&mut self, maximum_packet_size: u32) { self.maximum_packet_size = Some(maximum_packet_size) } @@ -175,13 +114,13 @@ impl ConnectOptions { self.topic_alias_maximum = None } pub fn set_request_response_information(&mut self, request_response_information: bool) { - self.request_response_information = Some(if request_response_information { 1 } else { 0 }) + self.request_response_information = Some(u8::from(request_response_information)) } pub fn clear_request_response_information(&mut self) { self.request_response_information = None } pub fn set_request_problem_information(&mut self, request_problem_information: bool) { - self.request_problem_information = Some(if request_problem_information { 1 } else { 0 }) + self.request_problem_information = Some(u8::from(request_problem_information)) } pub fn clear_request_problem_information(&mut self) { self.request_problem_information = None diff --git a/src/connections/async_native_tls.rs b/src/connections/async_native_tls.rs deleted file mode 100644 index 8ab5e56..0000000 --- a/src/connections/async_native_tls.rs +++ /dev/null @@ -1,236 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use async_channel::Receiver; -use async_native_tls::{TlsConnector, TlsStream}; -use bytes::{Buf, BytesMut}; -use smol::io::{ReadHalf, WriteHalf}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; - -use tracing::trace; - -use crate::error::TlsError; -use crate::{ - connect_options::ConnectOptions, connections::create_connect_from_options, - error::ConnectionError, network::Incoming, -}; -use crate::{ - connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite}, - packets::{ - error::ReadBytes, - packets::{FixedHeader, Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, -}; - -#[derive(Debug)] -pub struct TlsReader { - readhalf: ReadHalf>, - - /// Buffered reads - buffer: BytesMut, -} - -impl TlsReader { - pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), ConnectionError> { - if let Some(tls_config) = &options.tls_config { - let addr = options.address.clone(); - let tcp = TcpStream::connect((addr.as_str(), options.port)).await?; - - let connector = TlsConnector::new().use_sni(true); - - let mut connection = async_native_tls::connect("google.com", tcp).await.unwrap(); - - let (readhalf, writehalf) = split(connection); - - let reader = Self { - readhalf, - buffer: BytesMut::with_capacity(20 * 1024), - }; - let writer = TlsWriter::new(writehalf); - - Ok((reader, writer)) - } else { - Err(ConnectionError::TLS(TlsError::NoTlsConfig)) - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - #[cfg(feature = "tokio")] - let read = self.readhalf.read_buf(&mut self.buffer).await?; - #[cfg(feature = "smol")] - let read = self.readhalf.read(&mut self.buffer).await?; - if 0 == read { - return if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } -} - -impl AsyncMqttNetworkRead for TlsReader { - type W = TlsWriter; - - fn connect( - options: &ConnectOptions, - ) -> impl std::future::Future> + Send + '_ - { - async { - let (mut reader, mut writer) = TlsReader::new(options).await?; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - writer.write_buffer(&mut buf_out).await?; - - let packet = reader.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((reader, writer, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - } - - async fn read(&mut self) -> Result { - Ok(self.read().await?) - } - - async fn read_direct( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result { - let mut read_packets = 0; - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let disconnect = read_packet.packet_type() == PacketType::Disconnect; - incoming_packet_sender.send(read_packet).await?; - if disconnect { - return Ok(true); - } - read_packets += 1; - if read_packets >= 10 { - return Ok(false); - } - } - } -} -pub struct TlsWriter { - writehalf: WriteHalf>, - - buffer: BytesMut, -} - -impl TlsWriter { - pub fn new(writehalf: WriteHalf>) -> Self { - Self { - writehalf, - buffer: BytesMut::with_capacity(20 * 1024), - } - } -} - -impl AsyncMqttNetworkWrite for TlsWriter { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.writehalf.write_all(&buffer[..]).await?; - buffer.clear(); - Ok(()) - } - - async fn write(&mut self, outgoing: &Receiver) -> Result { - let mut disconnect = false; - - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - } - - while !outgoing.is_empty() && !disconnect { - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - break; - } - trace!("Going to write packet to network: {:?}", packet); - } - - self.writehalf.write_all(&self.buffer[..]).await?; - self.writehalf.flush().await?; - self.buffer.clear(); - Ok(disconnect) - } -} diff --git a/src/connections/async_rustls.rs b/src/connections/async_rustls.rs deleted file mode 100644 index 2e6f9f5..0000000 --- a/src/connections/async_rustls.rs +++ /dev/null @@ -1,297 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use async_channel::Receiver; - -use async_rustls::client::TlsStream; -use async_rustls::TlsConnector; -use bytes::{Buf, BytesMut}; -use rustls::ServerName; -use smol::io::{ReadHalf, WriteHalf}; -use smol::{ - io::{split, AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; - -use tracing::trace; - -use crate::error::TlsError; -use crate::{ - connect_options::ConnectOptions, connections::create_connect_from_options, - error::ConnectionError, network::Incoming, -}; -use crate::{ - connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite}, - packets::{ - error::ReadBytes, - {FixedHeader, Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, -}; - -use super::transport::{RustlsConfig, TlsConfig}; -use super::util::simple_rust_tls; - -#[derive(Debug)] -pub struct TlsReader { - readhalf: ReadHalf>, - - /// Input buffer - const_buffer: [u8; 1000], - /// Buffered reads - buffer: BytesMut, -} - -impl TlsReader { - pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), TlsError> { - if let Some(tls_config) = &options.tls_config { - let arc_tls_config = match tls_config { - TlsConfig::Rustls(RustlsConfig::Simple { - ca, - alpn, - client_auth, - }) => simple_rust_tls(ca.clone(), alpn.clone(), client_auth.clone())?, - TlsConfig::Rustls(RustlsConfig::Rustls(config)) => config.clone(), - }; - - let domain = ServerName::try_from(options.address.as_str())?; - let connector = TlsConnector::from(arc_tls_config); - - let stream = TcpStream::connect((options.address.as_str(), options.port)).await?; - let connection = connector.connect(domain, stream).await?; - - let (readhalf, writehalf) = split(connection); - - let reader = Self { - readhalf, - const_buffer: [0; 1000], - buffer: BytesMut::with_capacity(20 * 1024), - }; - let writer = TlsWriter::new(writehalf); - - Ok((reader, writer)) - } else { - Err(TlsError::NoTlsConfig) - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.readhalf.read(&mut self.const_buffer).await?; - - if read == 0 { - return if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - }; - } - else { - self.buffer.extend_from_slice(&self.const_buffer[0..read]); - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } -} - -impl AsyncMqttNetworkRead for TlsReader { - type W = TlsWriter; - - fn connect( - options: &ConnectOptions, - ) -> impl std::future::Future> + Send + '_ - { - async { - let (mut reader, mut writer) = TlsReader::new(options).await?; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - writer.write_buffer(&mut buf_out).await?; - - let packet = reader.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((reader, writer, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - } - - async fn read(&mut self) -> Result { - Ok(self.read().await?) - } - - async fn read_direct( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result { - let mut read_packets = 0; - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let disconnect = read_packet.packet_type() == PacketType::Disconnect; - incoming_packet_sender.send(read_packet).await?; - if disconnect { - return Ok(true); - } - read_packets += 1; - if read_packets >= 10 { - return Ok(false); - } - } - } -} - -#[derive(Debug)] -pub struct TlsWriter { - writehalf: WriteHalf>, - - buffer: BytesMut, -} - -impl TlsWriter { - pub fn new(writehalf: WriteHalf>) -> Self { - Self { - writehalf, - buffer: BytesMut::with_capacity(20 * 1024), - } - } -} - -impl AsyncMqttNetworkWrite for TlsWriter { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.writehalf.write_all(&buffer[..]).await?; - self.writehalf.flush().await?; - buffer.clear(); - Ok(()) - } - - async fn write(&mut self, outgoing: &Receiver) -> Result { - let mut disconnect = false; - - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - } - - while !outgoing.is_empty() && !disconnect { - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - break; - } - trace!("Going to write packet to network: {:?}", packet); - } - - self.writehalf.write_all(&self.buffer[..]).await?; - self.writehalf.flush().await?; - self.buffer.clear(); - Ok(disconnect) - } -} - -#[cfg(test)] -mod test { - use crate::{ - connect_options::ConnectOptions, - connections::{ - transport::{RustlsConfig, TlsConfig}, - AsyncMqttNetworkRead, - }, - packets::{ - {Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, - tests::resources::EMQX_CERT, - }; - - use super::TlsReader; - - fn connect_emqx_test() { - let config = TlsConfig::Rustls(RustlsConfig::Simple { - ca: EMQX_CERT.to_vec(), - alpn: None, - client_auth: None, - }); - - let opt = ConnectOptions::new_with_tls_config( - "broker.emqx.io".to_string(), - 8883, - "test123123".to_string(), - Some(config), - ); - - let (_, _, packet) = smol::block_on(TlsReader::connect(&opt)).unwrap(); - - assert_eq!(PacketType::ConnAck, packet.packet_type()); - if let Packet::ConnAck(conn) = packet { - assert_eq!(ConnAckReasonCode::Success, conn.reason_code); - } - } -} diff --git a/src/connections/mod.rs b/src/connections/mod.rs index 1a15d6a..740a9ac 100644 --- a/src/connections/mod.rs +++ b/src/connections/mod.rs @@ -1,77 +1,26 @@ -#[cfg(all(feature = "tokio", feature = "tcp"))] -pub mod tokio_tcp; - -#[cfg(all(feature = "smol", feature = "tcp"))] -pub mod smol_tcp; - -#[cfg(all(feature = "smol", feature = "native-tls"))] -pub mod async_native_tls; - -#[cfg(all(feature = "smol", feature = "smol-rustls"))] -pub mod async_rustls; - #[cfg(all(feature = "quic"))] pub mod quic; - -#[cfg(all(feature = "tokio", feature = "tokio-rustls"))] -pub mod tokio_rustls; - -#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] -pub mod transport; - -#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] -mod util; - -use std::future::Future; - -use async_channel::{Receiver, Sender}; -use bytes::BytesMut; +#[cfg(feature = "smol")] +pub mod smol_stream; +#[cfg(feature = "tokio")] +pub mod tokio_stream; use crate::connect_options::ConnectOptions; -use crate::error::ConnectionError; use crate::packets::Connect; use crate::packets::Packet; pub fn create_connect_from_options(options: &ConnectOptions) -> Packet { - let mut connect = Connect::default(); - - connect.client_id = options.client_id.clone(); - connect.clean_session = options.clean_session; - connect.keep_alive = options.keep_alive_interval_s as u16; - connect.connect_properties.request_problem_information = Some(1u8); - connect.connect_properties.request_response_information = Some(1u8); - connect.username = options.username.clone(); - connect.password = options.password.clone(); + let mut connect = Connect { + client_id: options.client_id.clone(), + clean_session: options.clean_session, + keep_alive: options.keep_alive_interval_s as u16, + username: options.username.clone(), + password: options.password.clone(), + ..Default::default() + }; + + connect.connect_properties.request_problem_information = options.request_problem_information; + connect.connect_properties.request_response_information = options.request_response_information; Packet::Connect(connect) } - -pub trait AsyncMqttNetwork: Sized + Sync + 'static { - fn connect( - options: &ConnectOptions, - ) -> impl Future> + Send + '_; - - async fn read(&self) -> Result; - - async fn read_many(&self, receiver: &Sender) -> Result<(), ConnectionError>; - - async fn write(&self, write_buf: &mut BytesMut) -> Result<(), ConnectionError>; -} - -pub trait AsyncMqttNetworkRead: Sized + Sync { - type W; - - fn connect( - options: &ConnectOptions, - ) -> impl Future> + Send + '_; - - async fn read(&mut self) -> Result; - - async fn read_direct(&mut self, sender: &Sender) -> Result; -} - -pub trait AsyncMqttNetworkWrite: Sized + Sync { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError>; - - async fn write(&mut self, outgoing: &Receiver) -> Result; -} diff --git a/src/connections/smol_stream.rs b/src/connections/smol_stream.rs new file mode 100644 index 0000000..f3b4cc5 --- /dev/null +++ b/src/connections/smol_stream.rs @@ -0,0 +1,176 @@ +use std::io::{self, Error, ErrorKind}; + +use bytes::{Buf, BytesMut}; +use smol::io::{AsyncReadExt, AsyncWriteExt}; + +use futures::{AsyncRead, AsyncWrite}; + +use tracing::trace; + +use crate::packets::{ + error::ReadBytes, + reason_codes::ConnAckReasonCode, + {FixedHeader, Packet, PacketType}, +}; +use crate::{ + connect_options::ConnectOptions, connections::create_connect_from_options, + error::ConnectionError, +}; + +#[derive(Debug)] +pub struct SmolStream { + pub stream: S, + + /// Input buffer + const_buffer: [u8; 1000], + /// Buffered reads + buffer: BytesMut, +} + +impl SmolStream +where + S: AsyncRead + AsyncWrite + Sized + Unpin, +{ + pub async fn connect( + options: &ConnectOptions, + stream: S, + ) -> Result<(Self, Packet), ConnectionError> { + let mut s = Self { + stream, + const_buffer: [0; 1000], + buffer: BytesMut::new(), + }; + + let mut buf_out = BytesMut::new(); + + create_connect_from_options(options).write(&mut buf_out)?; + + s.write_buffer(&mut buf_out).await?; + + let packet = s.read().await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + Ok((s, Packet::ConnAck(con))) + } + else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } + } + else { + Err(ConnectionError::NotConnAck(packet)) + } + } + + pub async fn parse_messages( + &mut self, + incoming_packet_sender: &async_channel::Sender, + ) -> Result, ReadBytes> { + let mut ret_packet_type = None; + loop { + if self.buffer.is_empty() { + return Ok(ret_packet_type); + } + let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?; + + if header.remaining_length > self.buffer.len() { + return Err(ReadBytes::InsufficientBytes( + header.remaining_length - self.buffer.len(), + )); + } + + self.buffer.advance(header_length); + + let buf = self.buffer.split_to(header.remaining_length); + let read_packet = Packet::read(header, buf.into())?; + tracing::trace!("Read packet from network {}", read_packet); + let packet_type = read_packet.packet_type(); + incoming_packet_sender.send(read_packet).await?; + + match packet_type { + PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)), + PacketType::PingResp => return Ok(Some(PacketType::PingResp)), + packet_type => ret_packet_type = Some(packet_type), + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { + Ok(header) => header, + Err(ReadBytes::InsufficientBytes(required_len)) => { + self.read_required_bytes(required_len).await?; + continue; + } + Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), + }; + + self.buffer.advance(header_length); + + if header.remaining_length > self.buffer.len() { + self.read_required_bytes(header.remaining_length - self.buffer.len()) + .await?; + } + + let buf = self.buffer.split_to(header.remaining_length); + + return Packet::read(header, buf.into()) + .map_err(|err| Error::new(ErrorKind::InvalidData, err)); + } + } + + pub async fn read_bytes(&mut self) -> io::Result { + let read = self.stream.read(&mut self.const_buffer).await?; + if 0 == read { + if self.buffer.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "Connection closed by peer", + )) + } + else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "Connection reset by peer", + )) + } + } + else { + self.buffer.extend_from_slice(&self.const_buffer[0..read]); + Ok(read) + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + + loop { + let read = self.read_bytes().await?; + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { + if buffer.is_empty() { + return Ok(()); + } + + self.stream.write_all(&buffer[..]).await?; + buffer.clear(); + Ok(()) + } + + pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + packet.write(&mut self.buffer)?; + trace!("Sending packet {}", packet); + + self.stream.write_all(&self.buffer[..]).await?; + self.stream.flush().await?; + self.buffer.clear(); + Ok(()) + } +} diff --git a/src/connections/smol_tcp.rs b/src/connections/smol_tcp.rs deleted file mode 100644 index 17ebaf6..0000000 --- a/src/connections/smol_tcp.rs +++ /dev/null @@ -1,220 +0,0 @@ - -use std::io::{self, Error, ErrorKind}; - -use async_channel::Receiver; -use bytes::{Buf, BytesMut}; -// #[cfg(feature = "smol")] -use smol::{ - io::{AsyncReadExt, AsyncWriteExt, split, ReadHalf, WriteHalf}, - net::TcpStream, -}; -use tracing::trace; - -use crate::packets::{ - error::ReadBytes, - {FixedHeader, Packet, PacketType}, - reason_codes::ConnAckReasonCode, -}; -use crate::{ - connect_options::ConnectOptions, - connections::{create_connect_from_options, AsyncMqttNetworkWrite}, - error::ConnectionError, - network::Incoming, -}; - -use super::AsyncMqttNetworkRead; - -#[derive(Debug)] -pub struct TcpReader { - readhalf: ReadHalf, - - /// Input buffer - const_buffer: [u8; 1000], - /// Buffered reads - buffer: BytesMut, -} - -impl TcpReader { - pub async fn new_tcp( - options: &ConnectOptions, - ) -> Result<(TcpReader, TcpWriter), ConnectionError> { - let (readhalf, writehalf) = split(TcpStream::connect((options.address.clone(), options.port)).await?); - let reader = TcpReader { - readhalf, - const_buffer: [0; 1000], - buffer: BytesMut::with_capacity(20 * 1024), - // max_incoming_size: u32::MAX as usize, - }; - let writer = TcpWriter::new(writehalf); - Ok((reader, writer)) - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.readhalf.read(&mut self.const_buffer).await?; - if 0 == read { - return if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - }; - } - else { - self.buffer.extend_from_slice(&self.const_buffer[0..read]); - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } -} - -impl AsyncMqttNetworkRead for TcpReader { - type W = TcpWriter; - - fn connect( - options: &ConnectOptions, - ) -> impl std::future::Future> + Send + '_ - { - async { - let (mut reader, mut writer) = TcpReader::new_tcp(options).await?; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - writer.write_buffer(&mut buf_out).await?; - - let packet = reader.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((reader, writer, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - } - - async fn read(&mut self) -> Result { - Ok(self.read().await?) - } - - async fn read_direct( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result { - let mut read_packets = 0; - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let disconnect = read_packet.packet_type() == PacketType::Disconnect; - incoming_packet_sender.send(read_packet).await?; - if disconnect { - return Ok(true); - } - read_packets += 1; - if read_packets >= 10 { - return Ok(false); - } - } - } -} - -pub struct TcpWriter { - writehalf: WriteHalf, - - buffer: BytesMut, -} - -impl TcpWriter { - pub fn new(writehalf: WriteHalf) -> Self { - Self { - writehalf, - buffer: BytesMut::with_capacity(20 * 1024), - } - } -} - -impl AsyncMqttNetworkWrite for TcpWriter { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.writehalf.write_all(&buffer[..]).await?; - buffer.clear(); - Ok(()) - } - - async fn write(&mut self, outgoing: &Receiver) -> Result { - let mut disconnect = false; - - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - } - trace!("Sending packet {}", packet); - - self.writehalf.write_all(&self.buffer[..]).await?; - self.writehalf.flush().await?; - self.buffer.clear(); - Ok(disconnect) - } -} diff --git a/src/connections/tokio_rustls.rs b/src/connections/tokio_rustls.rs deleted file mode 100644 index a09bc58..0000000 --- a/src/connections/tokio_rustls.rs +++ /dev/null @@ -1,290 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use async_channel::Receiver; - -use bytes::{Buf, BytesMut}; -use rustls::ServerName; - -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::TcpStream; -use tokio_rustls::client::TlsStream; -use tokio_rustls::TlsConnector; -use tracing::trace; - -use crate::error::TlsError; -use crate::{ - connect_options::ConnectOptions, connections::create_connect_from_options, - error::ConnectionError, network::Incoming, -}; -use crate::{ - connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite}, - packets::{ - error::ReadBytes, - {FixedHeader, Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, -}; - -use super::transport::{RustlsConfig, TlsConfig}; -use super::util::simple_rust_tls; - -#[derive(Debug)] -pub struct TlsReader { - readhalf: ReadHalf>, - - /// Buffered reads - buffer: BytesMut, -} - -impl TlsReader { - pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), TlsError> { - if let Some(tls_config) = &options.tls_config { - let arc_tls_config = match tls_config { - TlsConfig::Rustls(RustlsConfig::Simple { - ca, - alpn, - client_auth, - }) => simple_rust_tls(ca.clone(), alpn.clone(), client_auth.clone())?, - TlsConfig::Rustls(RustlsConfig::Rustls(config)) => config.clone(), - }; - - let domain = ServerName::try_from(options.address.as_str())?; - let connector = TlsConnector::from(arc_tls_config); - - let stream = TcpStream::connect((options.address.as_str(), options.port)).await?; - let connection = connector.connect(domain, stream).await?; - - let (readhalf, writehalf) = tokio::io::split(connection); - - let reader = Self { - readhalf, - buffer: BytesMut::with_capacity(20 * 1024), - }; - let writer = TlsWriter::new(writehalf); - - Ok((reader, writer)) - } else { - Err(TlsError::NoTlsConfig) - } - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - loop { - let read = self.readhalf.read_buf(&mut self.buffer).await?; - - if read == 0 { - return if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } -} - -impl AsyncMqttNetworkRead for TlsReader { - type W = TlsWriter; - - fn connect( - options: &ConnectOptions, - ) -> impl std::future::Future> + Send + '_ - { - async { - let (mut reader, mut writer) = TlsReader::new(options).await?; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - writer.write_buffer(&mut buf_out).await?; - - if !buf_out.is_empty() { - panic!("Should be empty"); - } - - let packet = reader.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((reader, writer, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - } - - async fn read(&mut self) -> Result { - Ok(self.read().await?) - } - - async fn read_direct( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result { - let mut read_packets = 0; - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let disconnect = read_packet.packet_type() == PacketType::Disconnect; - incoming_packet_sender.send(read_packet).await?; - if disconnect { - return Ok(true); - } - read_packets += 1; - if read_packets >= 10 { - return Ok(false); - } - } - } -} - -#[derive(Debug)] -pub struct TlsWriter { - writehalf: WriteHalf>, - buffer: BytesMut, -} - -impl TlsWriter { - pub fn new(writehalf: WriteHalf>) -> Self { - Self { - writehalf, - buffer: BytesMut::with_capacity(20 * 1024), - } - } -} - -impl AsyncMqttNetworkWrite for TlsWriter { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.writehalf.write_all(&buffer[..]).await?; - self.writehalf.flush().await?; - buffer.clear(); - Ok(()) - } - - async fn write(&mut self, outgoing: &Receiver) -> Result { - let mut disconnect = false; - - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - } - - while !outgoing.is_empty() && !disconnect { - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - break; - } - trace!("Going to write packet to network: {:?}", packet); - } - - self.writehalf.write_all(&self.buffer[..]).await?; - self.writehalf.flush().await?; - self.buffer.clear(); - Ok(disconnect) - } -} - -#[cfg(test)] -mod test { - use crate::{ - connect_options::ConnectOptions, - connections::{ - transport::{RustlsConfig, TlsConfig}, - AsyncMqttNetworkRead, - }, - packets::{ - {Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, - tests::resources::EMQX_CERT, - }; - - use super::TlsReader; - - async fn connect_emqx_test() { - let config = TlsConfig::Rustls(RustlsConfig::Simple { - ca: EMQX_CERT.to_vec(), - alpn: None, - client_auth: None, - }); - - let opt = ConnectOptions::new_with_tls_config( - "broker.emqx.io".to_string(), - 8883, - "test123123".to_string(), - Some(config), - ); - - let (_, _, packet) = TlsReader::connect(&opt).await.unwrap(); - - assert_eq!(PacketType::ConnAck, packet.packet_type()); - if let Packet::ConnAck(conn) = packet { - assert_eq!(ConnAckReasonCode::Success, conn.reason_code); - } - } -} diff --git a/src/connections/tokio_stream.rs b/src/connections/tokio_stream.rs new file mode 100644 index 0000000..954e87c --- /dev/null +++ b/src/connections/tokio_stream.rs @@ -0,0 +1,174 @@ +use std::io::{self, Error, ErrorKind}; + +use bytes::{Buf, BytesMut}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use tracing::trace; + +use crate::packets::{ + error::ReadBytes, + reason_codes::ConnAckReasonCode, + {FixedHeader, Packet, PacketType}, +}; +use crate::{ + connect_options::ConnectOptions, connections::create_connect_from_options, + error::ConnectionError, +}; + +#[derive(Debug)] +pub struct TokioStream { + pub stream: S, + + /// Input buffer + const_buffer: [u8; 1000], + /// Buffered reads + buffer: BytesMut, +} + +impl TokioStream +where + S: AsyncRead + AsyncWrite + Sized + Unpin, +{ + pub async fn connect( + options: &ConnectOptions, + stream: S, + ) -> Result<(Self, Packet), ConnectionError> { + let mut s = Self { + stream, + const_buffer: [0; 1000], + buffer: BytesMut::new(), + }; + + let mut buf_out = BytesMut::new(); + + create_connect_from_options(options).write(&mut buf_out)?; + + s.write_buffer(&mut buf_out).await?; + + let packet = s.read().await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + Ok((s, Packet::ConnAck(con))) + } + else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } + } + else { + Err(ConnectionError::NotConnAck(packet)) + } + } + + pub async fn parse_messages( + &mut self, + incoming_packet_sender: &async_channel::Sender, + ) -> Result, ReadBytes> { + let mut ret_packet_type = None; + loop { + if self.buffer.is_empty() { + return Ok(ret_packet_type); + } + let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?; + + if header.remaining_length > self.buffer.len() { + return Err(ReadBytes::InsufficientBytes( + header.remaining_length - self.buffer.len(), + )); + } + + self.buffer.advance(header_length); + + let buf = self.buffer.split_to(header.remaining_length); + let read_packet = Packet::read(header, buf.into())?; + tracing::trace!("Read packet from network {}", read_packet); + let packet_type = read_packet.packet_type(); + incoming_packet_sender.send(read_packet).await?; + + match packet_type { + PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)), + PacketType::PingResp => return Ok(Some(PacketType::PingResp)), + packet_type => ret_packet_type = Some(packet_type), + } + } + } + + pub async fn read(&mut self) -> io::Result { + loop { + let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { + Ok(header) => header, + Err(ReadBytes::InsufficientBytes(required_len)) => { + self.read_required_bytes(required_len).await?; + continue; + } + Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), + }; + + self.buffer.advance(header_length); + + if header.remaining_length > self.buffer.len() { + self.read_required_bytes(header.remaining_length - self.buffer.len()) + .await?; + } + + let buf = self.buffer.split_to(header.remaining_length); + + return Packet::read(header, buf.into()) + .map_err(|err| Error::new(ErrorKind::InvalidData, err)); + } + } + + pub async fn read_bytes(&mut self) -> io::Result { + let read = self.stream.read(&mut self.const_buffer).await?; + if 0 == read { + if self.buffer.is_empty() { + Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "Connection closed by peer", + )) + } + else { + Err(io::Error::new( + io::ErrorKind::ConnectionReset, + "Connection reset by peer", + )) + } + } + else { + self.buffer.extend_from_slice(&self.const_buffer[0..read]); + Ok(read) + } + } + + /// Reads more than 'required' bytes to frame a packet into self.read buffer + pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { + let mut total_read = 0; + + loop { + let read = self.read_bytes().await?; + total_read += read; + if total_read >= required { + return Ok(total_read); + } + } + } + + pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { + if buffer.is_empty() { + return Ok(()); + } + + self.stream.write_all(&buffer[..]).await?; + buffer.clear(); + Ok(()) + } + + pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + packet.write(&mut self.buffer)?; + trace!("Sending packet {}", packet); + + self.stream.write_all(&self.buffer[..]).await?; + self.stream.flush().await?; + self.buffer.clear(); + Ok(()) + } +} diff --git a/src/connections/tokio_tcp.rs b/src/connections/tokio_tcp.rs deleted file mode 100644 index 5a90ba5..0000000 --- a/src/connections/tokio_tcp.rs +++ /dev/null @@ -1,223 +0,0 @@ -use std::io::{self, Error, ErrorKind}; - -use async_channel::Receiver; -use bytes::{Buf, BytesMut}; -// #[cfg(feature = "smol")] -// use smol::{ -// io::{AsyncReadExt, AsyncWriteExt}, -// net::TcpStream, -// }; -use tokio::{io::AsyncReadExt, net::TcpStream}; -use tokio::{ - io::AsyncWriteExt, - net::tcp::{OwnedReadHalf, OwnedWriteHalf}, -}; -use tracing::trace; - -use crate::packets::{ - error::ReadBytes, - {FixedHeader, Packet, PacketType}, - reason_codes::ConnAckReasonCode, -}; -use crate::{ - connect_options::ConnectOptions, - connections::{create_connect_from_options, AsyncMqttNetworkWrite}, - error::ConnectionError, - network::Incoming, -}; - -use super::AsyncMqttNetworkRead; - -#[derive(Debug)] -pub struct TcpReader { - readhalf: OwnedReadHalf, - - /// Buffered reads - buffer: BytesMut, -} - -impl TcpReader { - pub async fn new_tcp( - options: &ConnectOptions, - ) -> Result<(TcpReader, TcpWriter), ConnectionError> { - let (readhalf, writehalf) = TcpStream::connect((options.address.clone(), options.port)) - .await? - .into_split(); - let reader = TcpReader { - readhalf, - buffer: BytesMut::with_capacity(20 * 1024), - // max_incoming_size: u32::MAX as usize, - }; - let writer = TcpWriter::new(writehalf); - Ok((reader, writer)) - } - - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()) - .map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - /// Reads more than 'required' bytes to frame a packet into self.read buffer - async fn read_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - #[cfg(feature = "tokio")] - let read = self.readhalf.read_buf(&mut self.buffer).await?; - // #[cfg(feature = "smol")] - // let read = self.connection.read(&mut self.buffer).await?; - if 0 == read { - return if self.buffer.is_empty() { - Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "Connection closed by peer", - )) - } else { - Err(io::Error::new( - io::ErrorKind::ConnectionReset, - "Connection reset by peer", - )) - }; - } - - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } -} - -impl AsyncMqttNetworkRead for TcpReader { - type W = TcpWriter; - - fn connect( - options: &ConnectOptions, - ) -> impl std::future::Future> + Send + '_ - { - async { - let (mut reader, mut writer) = TcpReader::new_tcp(options).await?; - - let mut buf_out = BytesMut::new(); - - create_connect_from_options(options).write(&mut buf_out)?; - - writer.write_buffer(&mut buf_out).await?; - - let packet = reader.read().await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - Ok((reader, writer, Packet::ConnAck(con))) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } - } else { - Err(ConnectionError::NotConnAck(packet)) - } - } - } - - async fn read(&mut self) -> Result { - Ok(self.read().await?) - } - - async fn read_direct( - &mut self, - incoming_packet_sender: &async_channel::Sender, - ) -> Result { - let mut read_packets = 0; - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)), - }; - - self.buffer.advance(header_length); - - if header.remaining_length > self.buffer.len() { - self.read_bytes(header.remaining_length - self.buffer.len()) - .await?; - } - - let buf = self.buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - tracing::trace!("Read packet from network {}", read_packet); - let disconnect = read_packet.packet_type() == PacketType::Disconnect; - incoming_packet_sender.send(read_packet).await?; - if disconnect { - return Ok(true); - } - read_packets += 1; - if read_packets >= 10 { - return Ok(false); - } - } - } -} - -pub struct TcpWriter { - writehalf: OwnedWriteHalf, - - buffer: BytesMut, -} - -impl TcpWriter { - pub fn new(writehalf: OwnedWriteHalf) -> Self { - Self { - writehalf, - buffer: BytesMut::with_capacity(20 * 1024), - } - } -} - -impl AsyncMqttNetworkWrite for TcpWriter { - async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> { - if buffer.is_empty() { - return Ok(()); - } - - self.writehalf.write_all(&buffer[..]).await?; - buffer.clear(); - Ok(()) - } - - async fn write(&mut self, outgoing: &Receiver) -> Result { - let mut disconnect = false; - - let packet = outgoing.recv().await?; - packet.write(&mut self.buffer)?; - if packet.packet_type() == PacketType::Disconnect { - disconnect = true; - } - trace!("Sending packet {}", packet); - - self.writehalf.write_all(&self.buffer[..]).await?; - self.writehalf.flush().await?; - self.buffer.clear(); - Ok(disconnect) - } -} diff --git a/src/connections/transport.rs b/src/connections/transport.rs deleted file mode 100644 index 9101622..0000000 --- a/src/connections/transport.rs +++ /dev/null @@ -1,31 +0,0 @@ -use rustls::ClientConfig; -use std::sync::Arc; - -#[derive(Debug, Clone)] -pub enum TlsConfig { - #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] - Rustls(RustlsConfig), - #[cfg(feature = "native-tls")] - Native { - ca: Vec, - der: Vec, - password: String, - }, -} - -#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))] -#[derive(Debug, Clone)] -pub enum RustlsConfig { - Simple { - ca: Vec, - alpn: Option>>, - client_auth: Option<(Vec, PrivateKey)>, - }, - Rustls(Arc), -} - -#[derive(Debug, Clone)] -pub enum PrivateKey { - RSA(Vec), - ECC(Vec), -} diff --git a/src/connections/util.rs b/src/connections/util.rs deleted file mode 100644 index 26adf12..0000000 --- a/src/connections/util.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::{ - io::{BufReader, Cursor}, - sync::Arc, -}; - -use rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore}; - -use super::transport::PrivateKey; -use crate::error::TlsError; - -// pub fn native() {} - -pub fn simple_rust_tls( - ca: Vec, - alpn: Option>>, - client_auth: Option<(Vec, PrivateKey)>, -) -> Result, TlsError> { - let mut root_cert_store = RootCertStore::empty(); - - let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?; - - let trust_anchors = ca_certs.iter().map_while(|cert| { - if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - )) - } else { - None - } - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - if root_cert_store.is_empty() { - return Err(TlsError::NoValidRootCertInChain); - } - - let config = ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); - - let mut config = match client_auth { - Some((client_cert_info, client_private_info)) => { - let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => { - rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))) - } - PrivateKey::ECC(ecc) => { - rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))) - } - } - .map_err(|_| TlsError::NoValidPrivateKey)?; - - let key = read_private_keys - .into_iter() - .next() - .ok_or(TlsError::NoValidPrivateKey)?; - - let client_certs = - rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info)))?; - let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); - - config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))? - } - None => config.with_no_client_auth(), - }; - - if let Some(alpn) = alpn { - config.alpn_protocols.extend(alpn) - } - - Ok(Arc::new(config)) -} diff --git a/src/error.rs b/src/error.rs index dd93d94..296dce3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,13 +2,10 @@ use std::io; use async_channel::{RecvError, SendError}; -use crate::{ - packets::{ - error::{DeserializeError, SerializeError}, - {Packet, PacketType}, - reason_codes::ConnAckReasonCode, - }, - util::timeout::Timeout, +use crate::packets::{ + error::{DeserializeError, ReadBytes, SerializeError}, + reason_codes::ConnAckReasonCode, + {Packet, PacketType}, }; #[derive(Debug, Clone, thiserror::Error)] @@ -44,14 +41,8 @@ pub enum ClientError { /// Critical errors during eventloop polling #[derive(Debug, thiserror::Error)] pub enum ConnectionError { - // #[error("Mqtt state: {0}")] - // MqttState(#[from] StateError), - #[error("Connect timeout")] - Timeout(#[from] Timeout), - - #[cfg(feature = "use-rustls")] - #[error("TLS: {0}")] - Tls(#[from] tls::Error), + #[error("No network connection")] + NoNetwork, #[error("No incoming packet handler available: {0}")] NoIncomingPacketHandler(#[from] SendError), @@ -76,9 +67,27 @@ pub enum ConnectionError { #[error("Requests done")] RequestsDone, +} + +impl From> for ReadBytes { + fn from(value: ReadBytes) -> Self { + match value { + ReadBytes::Err(err) => ReadBytes::Err(err.into()), + ReadBytes::InsufficientBytes(id) => ReadBytes::InsufficientBytes(id), + } + } +} + +impl From for ReadBytes { + fn from(value: DeserializeError) -> Self { + ReadBytes::Err(value.into()) + } +} - #[error("TLS Error")] - TLS(#[from] TlsError), +impl From> for ReadBytes { + fn from(value: SendError) -> Self { + ReadBytes::Err(value.into()) + } } #[derive(Debug, thiserror::Error)] diff --git a/src/event_handler.rs b/src/event_handler.rs index 5f81ad5..70035d5 100644 --- a/src/event_handler.rs +++ b/src/event_handler.rs @@ -1,36 +1,31 @@ use crate::connect_options::ConnectOptions; use crate::error::MqttError; +use crate::packets::reason_codes::{PubAckReasonCode, PubRecReasonCode}; use crate::packets::Disconnect; -use crate::packets::{Packet, PacketType}; -use crate::packets::{PubAck, PubAckProperties}; use crate::packets::PubComp; -use crate::packets::Publish; use crate::packets::PubRec; use crate::packets::PubRel; -use crate::packets::reason_codes::{PubAckReasonCode, PubRecReasonCode}; +use crate::packets::Publish; +use crate::packets::QoS; use crate::packets::SubAck; use crate::packets::Subscribe; use crate::packets::UnsubAck; use crate::packets::Unsubscribe; -use crate::packets::QoS; +use crate::packets::{Packet, PacketType}; +use crate::packets::{PubAck, PubAckProperties}; use crate::state::State; +use crate::{AsyncEventHandler, AsyncEventHandlerMut, HandlerStatus}; use futures::FutureExt; use async_channel::{Receiver, Sender}; -use async_mutex::Mutex; -use tracing::{error, debug}; +use tracing::error; -use std::future::Future; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::time::Instant; +#[cfg(test)] +use tracing::debug; /// Eventloop with all the state of a connection pub struct EventHandlerTask { - /// Options of the current mqtt connection - // options: ConnectOptions, - /// Current state of the connection state: State, network_receiver: Receiver, @@ -39,11 +34,6 @@ pub struct EventHandlerTask { client_to_handler_r: Receiver, - last_network_action: Arc>, - atomic_waiting_for_pingresp: AtomicBool, - waiting_for_pingresp: bool, - keep_alive_s: u64, - atomic_disconnect: AtomicBool, disconnect: bool, } @@ -57,7 +47,6 @@ impl EventHandlerTask { network_receiver: Receiver, network_sender: Sender, client_to_handler_r: Receiver, - last_network_action: Arc>, ) -> (Self, Receiver) { let (state, packet_id_channel) = State::new(options.receive_maximum()); @@ -69,90 +58,84 @@ impl EventHandlerTask { client_to_handler_r, - last_network_action, - - atomic_waiting_for_pingresp: AtomicBool::new(false), - waiting_for_pingresp: false, - - keep_alive_s: options.keep_alive_interval_s, - - atomic_disconnect: AtomicBool::new(false), disconnect: false, }; (task, packet_id_channel) } - - pub fn sync_handle( - &self, - handler: &mut H, - ) -> Result<(), MqttError> { - - match self.network_receiver.try_recv() { - Ok(event) => { - handler.handle(&event); - }, - Err(err) => { - if err.is_closed(){ - return Err(MqttError::IncomingNetworkChannelClosed) + // pub fn sync_handle(&self, handler: &mut H) -> Result<(), MqttError> { + // match self.network_receiver.try_recv() { + // Ok(event) => { + // handler.handle(&event); + // } + // Err(err) => { + // if err.is_closed() { + // return Err(MqttError::IncomingNetworkChannelClosed); + // } + // } + // } + // match self.client_to_handler_r.try_recv() { + // Ok(_) => {} + // Err(err) => { + // if err.is_closed() { + // return Err(MqttError::IncomingNetworkChannelClosed); + // } + // } + // } + // Ok(()) + // } + + pub async fn handle(&mut self, handler: &H) -> Result + where + H: AsyncEventHandler, { + futures::select! { + incoming = self.network_receiver.recv().fuse() => { + match incoming { + Ok(event) => { + // debug!("Event Handler, handling incoming packet: {}", event); + handler.handle(&event).await; + self.handle_incoming_packet(event).await?; + } + Err(_) => return Err(MqttError::IncomingNetworkChannelClosed), } - }, - } - match self.client_to_handler_r.try_recv() { - Ok(_) => { - - }, - Err(err) => { - if err.is_closed(){ - return Err(MqttError::IncomingNetworkChannelClosed) + if self.disconnect { + self.disconnect = true; + return Ok(HandlerStatus::IncomingDisconnect); } }, + outgoing = self.client_to_handler_r.recv().fuse() => { + match outgoing { + Ok(event) => { + // debug!("Event Handler, handling outgoing packet: {}", event); + self.handle_outgoing_packet(event).await? + } + Err(_) => return Err(MqttError::ClientChannelClosed), + } + if self.disconnect { + self.disconnect = true; + return Ok(HandlerStatus::OutgoingDisconnect); + } + } } - Ok(()) - // let keepalive = async { - // let initial_keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0); - // let mut keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0); - // loop { - // warn!("Awaiting PING sleep"); - // #[cfg(feature = "tokio")] - // tokio::time::sleep(keep_alive_duration).await; - - // if (!self.waiting_for_pingresp.load(Ordering::Acquire)) - // && self.last_network_action.lock().await.elapsed() - // >= initial_keep_alive_duration - // { - // self.network_sender.send(Packet::PingReq).await?; - // self.waiting_for_pingresp.store(true, Ordering::Release); - // keep_alive_duration = initial_keep_alive_duration; - // } else { - // keep_alive_duration = initial_keep_alive_duration - // - self.last_network_action.lock().await.elapsed(); - // } - // if self.disconnect.load(Ordering::Acquire) { - // return Ok::<(), MqttError>(()); - // } - // } - // }; + Ok(HandlerStatus::Active) } - - pub async fn handle( - &mut self, - handler: &mut H, - ) -> Result<(), MqttError> { + pub async fn handle_mut(&mut self, handler: &mut H) -> Result + where + H: AsyncEventHandlerMut, { futures::select! { incoming = self.network_receiver.recv().fuse() => { match incoming { Ok(event) => { // debug!("Event Handler, handling incoming packet: {}", event); handler.handle(&event).await; - self.handle_incoming_packet(event).await? + self.handle_incoming_packet(event).await?; } Err(_) => return Err(MqttError::IncomingNetworkChannelClosed), } - if self.atomic_disconnect.load(Ordering::Acquire) { - self.atomic_disconnect.store(false, Ordering::Release); - return Ok::<(), MqttError>(()); + if self.disconnect { + self.disconnect = true; + return Ok(HandlerStatus::IncomingDisconnect); } }, outgoing = self.client_to_handler_r.recv().fuse() => { @@ -163,43 +146,16 @@ impl EventHandlerTask { } Err(_) => return Err(MqttError::ClientChannelClosed), } - if self.atomic_disconnect.load(Ordering::Acquire) { - self.atomic_disconnect.store(false, Ordering::Release); - return Ok::<(), MqttError>(()); + if self.disconnect { + self.disconnect = true; + return Ok(HandlerStatus::OutgoingDisconnect); } } } - Ok(()) - // let keepalive = async { - // let initial_keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0); - // let mut keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0); - // loop { - // warn!("Awaiting PING sleep"); - // #[cfg(feature = "tokio")] - // tokio::time::sleep(keep_alive_duration).await; - - // if (!self.waiting_for_pingresp.load(Ordering::Acquire)) - // && self.last_network_action.lock().await.elapsed() - // >= initial_keep_alive_duration - // { - // self.network_sender.send(Packet::PingReq).await?; - // self.waiting_for_pingresp.store(true, Ordering::Release); - // keep_alive_duration = initial_keep_alive_duration; - // } else { - // keep_alive_duration = initial_keep_alive_duration - // - self.last_network_action.lock().await.elapsed(); - // } - // if self.disconnect.load(Ordering::Acquire) { - // return Ok::<(), MqttError>(()); - // } - // } - // }; + Ok(HandlerStatus::Active) } - async fn handle_incoming_packet( - &mut self, - packet: Packet, - ) -> Result<(), MqttError> { + async fn handle_incoming_packet(&mut self, packet: Packet) -> Result<(), MqttError> { match packet { Packet::Publish(publish) => self.handle_incoming_publish(&publish).await?, Packet::PubAck(puback) => self.handle_incoming_puback(&puback).await?, @@ -208,11 +164,11 @@ impl EventHandlerTask { Packet::PubComp(pubcomp) => self.handle_incoming_pubcomp(&pubcomp).await?, Packet::SubAck(suback) => self.handle_incoming_suback(suback).await?, Packet::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback).await?, - Packet::PingResp => self.handle_incoming_pingresp().await, + Packet::PingResp => (), Packet::ConnAck(_) => (), Packet::Disconnect(_) => { - self.atomic_disconnect.store(true, Ordering::Release); - }, + self.disconnect = true; + } a => unreachable!("Should not receive {}", a), }; Ok(()) @@ -252,7 +208,12 @@ impl EventHandlerTask { } async fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), MqttError> { - if let Some(_) = self.state.outgoing_pub.remove(&puback.packet_identifier) { + if self + .state + .outgoing_pub + .remove(&puback.packet_identifier) + .is_some() + { #[cfg(test)] debug!( "Publish {:?} has been acknowledged", @@ -263,7 +224,8 @@ impl EventHandlerTask { .mark_available(puback.packet_identifier) .await?; Ok(()) - } else { + } + else { error!( "Publish {:?} was not found, while receiving a PubAck for it", puback.packet_identifier, @@ -284,10 +246,7 @@ impl EventHandlerTask { self.network_sender.send(Packet::PubRel(pubrel)).await?; #[cfg(test)] - debug!( - "Publish {:?} has been PubReced", - pubrec.packet_identifier - ); + debug!("Publish {:?} has been PubReced", pubrec.packet_identifier); Ok(()) } _ => Ok(()), @@ -318,7 +277,8 @@ impl EventHandlerTask { .mark_available(pubcomp.packet_identifier) .await?; Ok(()) - } else { + } + else { error!( "PubRel {} was not found, while receiving a PubComp for it", pubcomp.packet_identifier, @@ -330,18 +290,20 @@ impl EventHandlerTask { } } - async fn handle_incoming_pingresp(&mut self) { - self.atomic_waiting_for_pingresp.store(false, Ordering::Release); - } - async fn handle_incoming_suback(&mut self, suback: SubAck) -> Result<(), MqttError> { - if self.state.outgoing_sub.remove(&suback.packet_identifier).is_some() { + if self + .state + .outgoing_sub + .remove(&suback.packet_identifier) + .is_some() + { self.state .apkid .mark_available(suback.packet_identifier) .await?; Ok(()) - } else { + } + else { error!( "Sub {} was not found, while receiving a SubAck for it", suback.packet_identifier, @@ -354,13 +316,19 @@ impl EventHandlerTask { } async fn handle_incoming_unsuback(&mut self, unsuback: UnsubAck) -> Result<(), MqttError> { - if self.state.outgoing_unsub.remove(&unsuback.packet_identifier).is_some() { + if self + .state + .outgoing_unsub + .remove(&unsuback.packet_identifier) + .is_some() + { self.state .apkid .mark_available(unsuback.packet_identifier) .await?; Ok(()) - } else { + } + else { error!( "Unsub {} was not found, while receiving a unsuback for it", unsuback.packet_identifier, @@ -391,8 +359,10 @@ impl EventHandlerTask { self.network_sender .send(Packet::Publish(publish.clone())) .await?; - if let Some(pub_collision) = - self.state.outgoing_pub.insert(publish.packet_identifier.unwrap(), publish) + if let Some(pub_collision) = self + .state + .outgoing_pub + .insert(publish.packet_identifier.unwrap(), publish) { error!( "Encountered a colliding packet ID ({:?}) in a publish QoS 1 packet", @@ -404,8 +374,10 @@ impl EventHandlerTask { self.network_sender .send(Packet::Publish(publish.clone())) .await?; - if let Some(pub_collision) = - self.state.outgoing_pub.insert(publish.packet_identifier.unwrap(), publish) + if let Some(pub_collision) = self + .state + .outgoing_pub + .insert(publish.packet_identifier.unwrap(), publish) { error!( "Encountered a colliding packet ID ({:?}) in a publish QoS 2 packet", @@ -418,7 +390,9 @@ impl EventHandlerTask { } async fn handle_outgoing_subscribe(&mut self, sub: Subscribe) -> Result<(), MqttError> { - if self.state.outgoing_sub + if self + .state + .outgoing_sub .insert(sub.packet_identifier, sub.clone()) .is_some() { @@ -426,14 +400,17 @@ impl EventHandlerTask { "Encountered a colliding packet ID ({}) in a subscribe packet", sub.packet_identifier, ) - } else { + } + else { self.network_sender.send(Packet::Subscribe(sub)).await?; } Ok(()) } async fn handle_outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), MqttError> { - if self.state.outgoing_unsub + if self + .state + .outgoing_unsub .insert(unsub.packet_identifier, unsub.clone()) .is_some() { @@ -441,13 +418,17 @@ impl EventHandlerTask { "Encountered a colliding packet ID ({}) in a unsubscribe packet", unsub.packet_identifier, ) - } else { + } + else { self.network_sender.send(Packet::Unsubscribe(unsub)).await?; } Ok(()) } - async fn handle_outgoing_disconnect(&mut self, disconnect: Disconnect) -> Result<(), MqttError> { + async fn handle_outgoing_disconnect( + &mut self, + disconnect: Disconnect, + ) -> Result<(), MqttError> { // self.atomic_disconnect.store(true, Ordering::Release); self.disconnect = true; self.network_sender @@ -457,372 +438,356 @@ impl EventHandlerTask { } } -pub trait AsyncEventHandler: Sized + Sync + 'static { - fn handle<'a>(&'a mut self, event: &'a Packet) -> impl Future + Send + 'a; -} - -pub trait EventHandler: Sized { - fn handle<'a>(&'a mut self, event: &'a Packet); -} - #[cfg(test)] mod handler_tests { - use std::{sync::Arc, time::Duration}; - - use async_channel::{Receiver, Sender}; - use async_mutex::Mutex; - - use crate::{ - connect_options::ConnectOptions, - event_handler::{AsyncEventHandler, EventHandlerTask}, - packets::{ - {Packet, PacketType}, - {PubComp, PubCompProperties}, - {PubRec, PubRecProperties}, - {PubRel, PubRelProperties}, - reason_codes::{ - PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, - }, - {SubAck, SubAckProperties}, - QoS, - }, - tests::resources::test_packets::{ - create_disconnect_packet, create_puback_packet, create_publish_packet, - create_subscribe_packet, - }, - }; - - pub struct Nop {} - impl AsyncEventHandler for Nop { - fn handle<'a>( - &'a mut self, - _event: &'a Packet, - ) -> impl core::future::Future + Send + 'a { - async move { () } - } - } - - fn handler() -> ( - EventHandlerTask, - Receiver, - Sender, - Sender, - ) { - let opt = ConnectOptions::new("127.0.0.1".to_string(), 1883, "test123123".to_string()); - - let (to_network_s, to_network_r) = async_channel::bounded(100); - let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); - let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100); - - let (handler, _apkid) = EventHandlerTask::new( - &opt, - network_to_handler_r, - to_network_s, - client_to_handler_r, - Arc::new(Mutex::new( - std::time::Instant::now() + Duration::new(600, 0), - )), - ); - ( - handler, - to_network_r, - network_to_handler_s, - client_to_handler_s, - ) - } - - #[tokio::test(flavor = "multi_thread")] - async fn outgoing_publish_qos_0() { - let mut nop = Nop {}; - - let (mut handler, to_network_r, _network_to_handler_s, client_to_handler_s) = handler(); - - let handler_task = tokio::task::spawn(async move { - let _ = loop{ - match handler.handle(&mut nop).await{ - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtMostOnce, false, false, None); - - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); - - let packet = to_network_r.recv().await.unwrap(); + // use std::{sync::Arc, time::Duration}; + + // use async_channel::{Receiver, Sender}; + // use async_mutex::Mutex; + + // use crate::{ + // connect_options::ConnectOptions, + // event_handler::{EventHandlerTask}, + // packets::{ + // reason_codes::{ + // PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, + // }, + // QoS, {Packet, PacketType}, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, + // {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, + // }, + // tests::resources::test_packets::{ + // create_disconnect_packet, create_puback_packet, create_publish_packet, + // create_subscribe_packet, + // }, + // }; + + // pub struct Nop {} + // // impl AsyncEventHandler for Nop { + // // fn handle<'a>( + // // &'a mut self, + // // _event: &'a Packet, + // // ) -> impl core::future::Future + Send + 'a { + // // async move { () } + // // } + // // } + + // fn handler() -> ( + // EventHandlerTask, + // Receiver, + // Sender, + // Sender, + // ) { + // let opt = ConnectOptions::new("test123123".to_string()); + + // let (to_network_s, to_network_r) = async_channel::bounded(100); + // let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); + // let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100); + + // let (handler, _apkid) = EventHandlerTask::new( + // &opt, + // network_to_handler_r, + // to_network_s, + // client_to_handler_r, + // ); + // ( + // handler, + // to_network_r, + // network_to_handler_s, + // client_to_handler_s, + // ) + // } + + // #[tokio::test(flavor = "multi_thread")] + // async fn outgoing_publish_qos_0() { + // let mut nop = Nop {}; + + // let (mut handler, to_network_r, _network_to_handler_s, client_to_handler_s) = handler(); + + // let handler_task = tokio::task::spawn(async move { + // let _ = loop { + // match handler.handle(&mut nop).await { + // Ok(_) => (), + // Err(_) => break, + // } + // }; + // return handler; + // }); + // let pub_packet = create_publish_packet(QoS::AtMostOnce, false, false, None); + + // client_to_handler_s.send(pub_packet.clone()).await.unwrap(); + + // let packet = to_network_r.recv().await.unwrap(); + + // assert_eq!(packet, pub_packet); + + // // If we drop the client to handler channel the handler will stop executing and we can inspect its internals. + // drop(client_to_handler_s); + + // let handler = handler_task.await.unwrap(); + + // assert!(handler.state.incoming_pub.is_empty()); + // assert!(handler.state.outgoing_pub.is_empty()); + // assert!(handler.state.outgoing_rel.is_empty()); + // assert!(handler.state.outgoing_sub.is_empty()); + // } + + // #[tokio::test(flavor = "multi_thread")] + // async fn outgoing_publish_qos_1() { + // pub struct TestPubQoS1 { + // stage: StagePubQoS1, + // } + // pub enum StagePubQoS1 { + // PubAck, + // Done, + // } + // impl TestPubQoS1 { + // fn new() -> Self { + // TestPubQoS1 { + // stage: StagePubQoS1::PubAck, + // } + // } + // } + // // impl AsyncEventHandler for TestPubQoS1 { + // // fn handle<'a>( + // // &'a mut self, + // // event: &'a Packet, + // // ) -> impl core::future::Future + Send + 'a { + // // async move { + // // match self.stage { + // // StagePubQoS1::PubAck => { + // // assert_eq!(event.packet_type(), PacketType::PubAck); + // // self.stage = StagePubQoS1::Done; + // // } + // // StagePubQoS1::Done => (), + // // } + // // } + // // } + // // } + + // let mut nop = TestPubQoS1::new(); + + // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + + // let handler_task = tokio::task::spawn(async move { + // // Ignore the error that this will return + // let _ = loop { + // match handler.handle(&mut nop).await { + // Ok(_) => (), + // Err(_) => break, + // } + // }; + // return handler; + // }); + // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + + // client_to_handler_s.send(pub_packet.clone()).await.unwrap(); + + // let publish = to_network_r.recv().await.unwrap(); + + // assert_eq!(pub_packet, publish); + + // let puback = create_puback_packet(1); + + // network_to_handler_s.send(puback).await.unwrap(); + + // tokio::time::sleep(Duration::new(5, 0)).await; + + // // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. + // drop(client_to_handler_s); + // drop(network_to_handler_s); + + // let handler = handler_task.await.unwrap(); + + // assert!(handler.state.incoming_pub.is_empty()); + // assert!(handler.state.outgoing_pub.is_empty()); + // assert!(handler.state.outgoing_rel.is_empty()); + // assert!(handler.state.outgoing_sub.is_empty()); + // } + + // #[tokio::test(flavor = "multi_thread")] + // async fn incoming_publish_qos_1() { + // let mut nop = Nop {}; + + // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + + // let handler_task = tokio::task::spawn(async move { + // // Ignore the error that this will return + // let _ = loop { + // match handler.handle(&mut nop).await { + // Ok(_) => (), + // Err(_) => break, + // } + // }; + // return handler; + // }); + // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + + // network_to_handler_s.send(pub_packet.clone()).await.unwrap(); + + // let puback = to_network_r.recv().await.unwrap(); + + // assert_eq!(PacketType::PubAck, puback.packet_type()); + + // let expected_puback = create_puback_packet(1); + + // assert_eq!(expected_puback, puback); + + // // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. + // drop(client_to_handler_s); + // drop(network_to_handler_s); + + // let handler = handler_task.await.unwrap(); + + // assert!(handler.state.incoming_pub.is_empty()); + // assert!(handler.state.outgoing_pub.is_empty()); + // assert!(handler.state.outgoing_rel.is_empty()); + // assert!(handler.state.outgoing_sub.is_empty()); + // } + + // #[tokio::test(flavor = "multi_thread")] + // async fn outgoing_publish_qos_2() { + // pub struct TestPubQoS2 { + // stage: StagePubQoS2, + // client_to_handler_s: Sender, + // } + // pub enum StagePubQoS2 { + // PubRec, + // PubComp, + // Done, + // } + // impl TestPubQoS2 { + // fn new(client_to_handler_s: Sender) -> Self { + // TestPubQoS2 { + // stage: StagePubQoS2::PubRec, + // client_to_handler_s, + // } + // } + // } + // // impl AsyncEventHandler for TestPubQoS2 { + // // fn handle<'a>( + // // &'a mut self, + // // event: &'a Packet, + // // ) -> impl core::future::Future + Send + 'a { + // // async move { + // // match self.stage { + // // StagePubQoS2::PubRec => { + // // assert_eq!(event.packet_type(), PacketType::PubRec); + // // self.stage = StagePubQoS2::PubComp; + // // } + // // StagePubQoS2::PubComp => { + // // assert_eq!(event.packet_type(), PacketType::PubComp); + // // self.stage = StagePubQoS2::Done; + // // self.client_to_handler_s + // // .send(create_disconnect_packet()) + // // .await + // // .unwrap(); + // // } + // // StagePubQoS2::Done => (), + // // } + // // } + // // } + // // } + + // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + + // let mut nop = TestPubQoS2::new(client_to_handler_s.clone()); + + // let handler_task = tokio::task::spawn(async move { + // let _ = loop { + // match handler.handle(&mut nop).await { + // Ok(_) => (), + // Err(_) => break, + // } + // }; + // return handler; + // }); + // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); + + // client_to_handler_s.send(pub_packet.clone()).await.unwrap(); + + // let publish = to_network_r.recv().await.unwrap(); + + // assert_eq!(pub_packet, publish); + + // let pubrec = Packet::PubRec(PubRec { + // packet_identifier: 1, + // reason_code: PubRecReasonCode::Success, + // properties: PubRecProperties::default(), + // }); + + // network_to_handler_s.send(pubrec).await.unwrap(); + + // let packet = to_network_r.recv().await.unwrap(); + + // let expected_pubrel = Packet::PubRel(PubRel { + // packet_identifier: 1, + // reason_code: PubRelReasonCode::Success, + // properties: PubRelProperties::default(), + // }); + + // assert_eq!(expected_pubrel, packet); + + // let pubcomp = Packet::PubComp(PubComp { + // packet_identifier: 1, + // reason_code: PubCompReasonCode::Success, + // properties: PubCompProperties::default(), + // }); + + // network_to_handler_s.send(pubcomp).await.unwrap(); + + // drop(client_to_handler_s); + // drop(network_to_handler_s); + + // let handler = handler_task.await.unwrap(); + + // assert!(handler.state.incoming_pub.is_empty()); + // assert!(handler.state.outgoing_pub.is_empty()); + // assert!(handler.state.outgoing_rel.is_empty()); + // assert!(handler.state.outgoing_sub.is_empty()); + // } + + // #[tokio::test(flavor = "multi_thread")] + // async fn outgoing_subscribe() { + // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); + + // let mut nop = Nop {}; + + // let handler_task = tokio::task::spawn(async move { + // // Ignore the error that this will return + // let _ = loop { + // match handler.handle(&mut nop).await { + // Ok(_) => (), + // Err(_) => break, + // } + // }; + // return handler; + // }); + + // let sub_packet = create_subscribe_packet(1); + + // client_to_handler_s.send(sub_packet.clone()).await.unwrap(); + + // let sub_result = to_network_r.recv().await.unwrap(); + + // assert_eq!(sub_packet, sub_result); + + // let suback = Packet::SubAck(SubAck { + // packet_identifier: 1, + // reason_codes: vec![SubAckReasonCode::GrantedQoS0], + // properties: SubAckProperties::default(), + // }); - assert_eq!(packet, pub_packet); + // network_to_handler_s.send(suback).await.unwrap(); - // If we drop the client to handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); + // tokio::time::sleep(Duration::new(2, 0)).await; - let handler = handler_task.await.unwrap(); + // drop(client_to_handler_s); + // drop(network_to_handler_s); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn outgoing_publish_qos_1() { - pub struct TestPubQoS1 { - stage: StagePubQoS1, - } - pub enum StagePubQoS1 { - PubAck, - Done, - } - impl TestPubQoS1 { - fn new() -> Self { - TestPubQoS1 { - stage: StagePubQoS1::PubAck, - } - } - } - impl AsyncEventHandler for TestPubQoS1 { - fn handle<'a>( - &'a mut self, - event: &'a Packet, - ) -> impl core::future::Future + Send + 'a { - async move { - match self.stage { - StagePubQoS1::PubAck => { - assert_eq!(event.packet_type(), PacketType::PubAck); - self.stage = StagePubQoS1::Done; - } - StagePubQoS1::Done => (), - } - } - } - } - - let mut nop = TestPubQoS1::new(); - - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop{ - match handler.handle(&mut nop).await{ - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); - - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); - - let publish = to_network_r.recv().await.unwrap(); - - assert_eq!(pub_packet, publish); - - let puback = create_puback_packet(1); - - network_to_handler_s.send(puback).await.unwrap(); - - tokio::time::sleep(Duration::new(5, 0)).await; - - // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); - drop(network_to_handler_s); - - let handler = handler_task.await.unwrap(); - - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn incoming_publish_qos_1(){ - - let mut nop = Nop{}; - - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop{ - match handler.handle(&mut nop).await{ - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); - - network_to_handler_s.send(pub_packet.clone()).await.unwrap(); - - let puback = to_network_r.recv().await.unwrap(); - - assert_eq!(PacketType::PubAck, puback.packet_type()); - - let expected_puback = create_puback_packet(1); - - assert_eq!(expected_puback, puback); - - // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals. - drop(client_to_handler_s); - drop(network_to_handler_s); - - let handler = handler_task.await.unwrap(); - - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn outgoing_publish_qos_2() { - pub struct TestPubQoS2 { - stage: StagePubQoS2, - client_to_handler_s: Sender, - } - pub enum StagePubQoS2 { - PubRec, - PubComp, - Done, - } - impl TestPubQoS2 { - fn new(client_to_handler_s: Sender) -> Self { - TestPubQoS2 { - stage: StagePubQoS2::PubRec, - client_to_handler_s, - } - } - } - impl AsyncEventHandler for TestPubQoS2 { - fn handle<'a>( - &'a mut self, - event: &'a Packet, - ) -> impl core::future::Future + Send + 'a { - async move { - match self.stage { - StagePubQoS2::PubRec => { - assert_eq!(event.packet_type(), PacketType::PubRec); - self.stage = StagePubQoS2::PubComp; - } - StagePubQoS2::PubComp => { - assert_eq!(event.packet_type(), PacketType::PubComp); - self.stage = StagePubQoS2::Done; - self.client_to_handler_s - .send(create_disconnect_packet()) - .await - .unwrap(); - } - StagePubQoS2::Done => (), - } - } - } - } - - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let mut nop = TestPubQoS2::new(client_to_handler_s.clone()); - - let handler_task = tokio::task::spawn(async move { - let _ = loop{ - match handler.handle(&mut nop).await{ - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1)); - - client_to_handler_s.send(pub_packet.clone()).await.unwrap(); - - let publish = to_network_r.recv().await.unwrap(); - - assert_eq!(pub_packet, publish); - - let pubrec = Packet::PubRec(PubRec { - packet_identifier: 1, - reason_code: PubRecReasonCode::Success, - properties: PubRecProperties::default(), - }); - - network_to_handler_s.send(pubrec).await.unwrap(); - - let packet = to_network_r.recv().await.unwrap(); - - let expected_pubrel = Packet::PubRel(PubRel { - packet_identifier: 1, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties::default(), - }); - - assert_eq!(expected_pubrel, packet); - - let pubcomp = Packet::PubComp(PubComp { - packet_identifier: 1, - reason_code: PubCompReasonCode::Success, - properties: PubCompProperties::default(), - }); - - network_to_handler_s.send(pubcomp).await.unwrap(); - - drop(client_to_handler_s); - drop(network_to_handler_s); - - let handler = handler_task.await.unwrap(); - - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); - } - - #[tokio::test(flavor = "multi_thread")] - async fn outgoing_subscribe() { - let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler(); - - let mut nop = Nop {}; - - let handler_task = tokio::task::spawn(async move { - // Ignore the error that this will return - let _ = loop{ - match handler.handle(&mut nop).await{ - Ok(_) => (), - Err(_) => break, - } - }; - return handler; - }); - - let sub_packet = create_subscribe_packet(1); - - client_to_handler_s.send(sub_packet.clone()).await.unwrap(); - - let sub_result = to_network_r.recv().await.unwrap(); - - assert_eq!(sub_packet, sub_result); - - let suback = Packet::SubAck(SubAck { - packet_identifier: 1, - reason_codes: vec![SubAckReasonCode::GrantedQoS0], - properties: SubAckProperties::default(), - }); - - network_to_handler_s.send(suback).await.unwrap(); - - tokio::time::sleep(Duration::new(2, 0)).await; - - drop(client_to_handler_s); - drop(network_to_handler_s); - - let handler = handler_task.await.unwrap(); - assert!(handler.state.incoming_pub.is_empty()); - assert!(handler.state.outgoing_pub.is_empty()); - assert!(handler.state.outgoing_rel.is_empty()); - assert!(handler.state.outgoing_sub.is_empty()); - } + // let handler = handler_task.await.unwrap(); + // assert!(handler.state.incoming_pub.is_empty()); + // assert!(handler.state.outgoing_pub.is_empty()); + // assert!(handler.state.outgoing_rel.is_empty()); + // assert!(handler.state.outgoing_sub.is_empty()); + // } } diff --git a/src/lib.rs b/src/lib.rs index 917f47f..ba73a9e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,23 +1,27 @@ //! A pure rust MQTT client which strives to be as efficient as possible. //! This crate strives to provide an ergonomic API and design that fits Rust. -//! +//! //! There are three parts to the design of the MQTT client. The network, the event handler and the client. -//! +//! //! The network - which simply reads and forms packets from the network. //! The event handler - which makes sure that the MQTT protocol is followed. //! By providing a custom handler during the internal handling, messages are handled before they are acked. //! The client - which is used to send messages from different places. -//! +//! //! To Do: //! - Rebroadcast unacked packets. -//! - Keep alive sending of PingReq and PingResp. //! - Enforce size of outbound messages (e.g. Publish) -//! +//! - Sync API +//! //! A few questions still remain: //! - This crate uses async channels to perform communication across its parts. Is there a better approach? -//! These channels do allow the user to decouple this crate and the network in the future but comes at a cost of more copies -//! - This crate provides network implementation which hinder sync and async agnosticism. -//! +//! These channels do allow the user to decouple the network, handlers, and clients very easily. +//! - This crate provides network implementation which hinder syn, counter: 0c and async agnosticism. +//! Would a true sansio implementation be better? +//! At first this crate used custom async traits which are not stable (async_fn_in_trait). +//! The current version allows the user to provide the appropriate stream. +//! This also nicely relives us from having to deal with TLS configuration. +//! //! For the future it could be nice to be sync, async and runtime agnostic. //! This can be achieved by decoupling the MQTT internals from the network communication. //! The user could provide the received packets while this crate returns the response packets. @@ -27,7 +31,7 @@ //! //! Tokio example: //! ---------------------------- -//! ```no_run +//! ```ignore //! let config = RustlsConfig::Simple { //! ca: EMQX_CERT.to_vec(), //! alpn: None, @@ -35,85 +39,101 @@ //! }; //! let opt = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "test123123".to_string()); //! let (mqtt_network, handler, client) = create_tokio_rustls(opt, config); -//! +//! //! task::spawn(async move { //! join!(mqtt_network.run(), handler.handle(/* Custom handler */)); //! }); -//! +//! //! for i in 0..10 { //! client.publish("test", QoS::AtLeastOnce, false, b"test payload").await.unwrap(); //! time::sleep(Duration::from_millis(100)).await; //! } //! ``` -//! +//! //! Smol example: -//! //! ``` +//! use mqrstt::{ +//! client::AsyncClient, +//! connect_options::ConnectOptions, +//! new_smol, +//! packets::{self, Packet}, +//! AsyncEventHandlerMut, HandlerStatus, NetworkStatus, +//! }; +//! use async_trait::async_trait; +//! use bytes::Bytes; //! pub struct PingPong { -//! pub client: AsyncClient, +//! pub client: AsyncClient, //! } -//! -//! impl EventHandler for PingPong{ -//! fn handle<'a>(&'a mut self, event: &'a packets::Packet) -> impl std::future::Future + Send + 'a { -//! async move{ -//! match event{ -//! Packet::Publish(p) => { -//! if let Ok(payload) = String::from_utf8(p.payload.to_vec()){ -//! if payload.to_lowercase().contains("ping"){ -//! self.client.publish(p.qos, p.retain, p.topic.clone(), Bytes::from_static(b"pong")).await; -//! println!("Received Ping, Send pong!"); -//! } +//! #[async_trait] +//! impl AsyncEventHandlerMut for PingPong { +//! async fn handle(&mut self, event: &packets::Packet) -> () { +//! match event { +//! Packet::Publish(p) => { +//! if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { +//! if payload.to_lowercase().contains("ping") { +//! self.client +//! .publish( +//! p.qos, +//! p.retain, +//! p.topic.clone(), +//! Bytes::from_static(b"pong"), +//! ) +//! .await +//! .unwrap(); +//! println!("Received Ping, Send pong!"); //! } -//! }, -//! Packet::ConnAck(_) => { -//! println!("Connected!"); //! } -//! _ => (), -//! } +//! }, +//! Packet::ConnAck(_) => { println!("Connected!") }, +//! _ => (), //! } //! } //! } -//! -//! fn main(){ -//! let options = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "mqrstt".to_string()); -//! -//! let tls_config = RustlsConfig::Simple { -//! ca: crate::tests::resources::EMQX_CERT.to_vec(), -//! alpn: None, -//! client_auth: None, +//! smol::block_on(async { +//! let options = ConnectOptions::new("mqrstt".to_string()); +//! let (mut network, mut handler, client) = new_smol(options); +//! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) +//! .await +//! .unwrap(); +//! network.connect(stream).await.unwrap(); +//! client.subscribe("mqrstt").await.unwrap(); +//! let mut pingpong = PingPong { +//! client: client.clone(), //! }; -//! -//! let (mut network, handler, client) = create_smol_rustls(options, tls_config); -//! smol::block_on(async{ -//! client.subscribe("mqrstt").await.unwrap(); -//! -//! let mut pingpong = PingPong{ client }; -//! -//! join!(network.run(), handler.handle(&mut pingpong)); -//! }); -//! } +//! let (n, h, t) = futures::join!( +//! async { +//! loop { +//! return match network.run().await { +//! Ok(NetworkStatus::Active) => continue, +//! otherwise => otherwise, +//! }; +//! } +//! }, +//! async { +//! loop { +//! return match handler.handle_mut(&mut pingpong).await { +//! Ok(HandlerStatus::Active) => continue, +//! otherwise => otherwise, +//! }; +//! } +//! }, +//! async { +//! smol::Timer::after(std::time::Duration::from_secs(60)).await; +//! client.disconnect().await.unwrap(); +//! } +//! ); +//! assert!(n.is_ok()); +//! assert!(h.is_ok()); +//! }); //! ``` -//! - -#![feature(async_fn_in_trait)] -#![feature(return_position_impl_trait_in_trait)] - -use std::{sync::Arc, time::Instant}; - -use async_mutex::Mutex; use client::AsyncClient; use connect_options::ConnectOptions; -use connections::*; - -// #[cfg(all(feature = "tokio", feature = "tcp"))] -// use connections::tokio_tcp::{TcpReader, TcpWriter}; - -use connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite}; - use event_handler::EventHandlerTask; -use network::MqttNetwork; + +use packets::Packet; +use smol_network::SmolNetwork; mod available_packet_ids; pub mod client; @@ -121,69 +141,108 @@ pub mod connect_options; pub mod connections; pub mod error; pub mod event_handler; -pub mod network; pub mod packets; +pub mod smol_network; pub mod state; mod util; #[cfg(test)] pub mod tests; +pub mod tokio_network; + +/// [`NetworkStatus`] Represents status of the Network object. +/// It is returned when the run handle returns from performing an operation. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum NetworkStatus { + Active, + IncomingDisconnect, + OutgoingDisconnect, + NoPingResp, +} -#[cfg(all(feature = "smol", feature = "smol-rustls"))] -pub fn create_smol_rustls( - mut options: ConnectOptions, - tls_config: transport::RustlsConfig, -) -> ( - MqttNetwork, - EventHandlerTask, - AsyncClient, -) { +/// [`HandlerStatus`] Represents status of the Network object. +/// It is returned when the run handle returns from performing an operation. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum HandlerStatus { + Active, + IncomingDisconnect, + OutgoingDisconnect, +} - options.tls_config = Some(transport::TlsConfig::Rustls(tls_config)); - new(options) +#[async_trait::async_trait] +/// Handlers are used to deal with packets before they are further processed (acked) +/// This guarantees that the end user has handlded the packet. +/// Trait for async mutable access to handler. +/// Usefull when you have a single handler +pub trait AsyncEventHandlerMut { + async fn handle(&mut self, event: &Packet); } -#[cfg(all(feature = "tokio", feature = "tokio-rustls"))] -pub fn create_tokio_rustls( - mut options: ConnectOptions, - tls_config: transport::RustlsConfig, -) -> ( - MqttNetwork, - EventHandlerTask, - AsyncClient, -) { +#[async_trait::async_trait] +/// Handlers are used to deal with packets before they are further processed (acked) +/// This guarantees that the end user has handlded the packet. +/// Trait for async immutable access to handler. +/// Usefull when you want to run multiple handlers concurrently to increase throughput. +pub trait AsyncEventHandler { + async fn handle(&self, event: &Packet); +} - options.tls_config = Some(transport::TlsConfig::Rustls(tls_config)); - new(options) +/// Handlers are used to deal with packets before they are further processed (acked) +/// This guarantees that the end user has handlded the packet. +/// Trait for sync mutable access to handler. +/// Usefull when you want to run multiple handlers concurrently to increase throughput. +pub trait EventHandlerMut { + fn handle(&mut self, event: &Packet); } -#[cfg(all(feature = "tokio", feature = "tcp"))] -pub fn create_tokio_tcp( - options: ConnectOptions, -) -> ( - MqttNetwork, - EventHandlerTask, - AsyncClient, -) { - new(options) +/// Handlers are used to deal with packets before they are further processed (acked) +/// This guarantees that the end user has handlded the packet. +/// Trait for sync immutable access to handler. +/// Usefull when you want to run multiple handlers concurrently to increase throughput. +pub trait EventHandler { + fn handle(&self, event: &Packet); +} + +// #[cfg(all(feature = "smol", feature = "tokio"))] +// std::compile_error!("The features smol and tokio can not be enabled simultaiously."); + +#[cfg(feature = "smol")] +/// Creates the needed components to run the MQTT client using a stream that implements [`smol::io::AsyncReadExt`] and [`smol::io::AsyncWriteExt`] +pub fn new_smol(options: ConnectOptions) -> (SmolNetwork, EventHandlerTask, AsyncClient) +where + S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { + let receive_maximum = options.receive_maximum(); + + let (to_network_s, to_network_r) = async_channel::bounded(100); + let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100); + let (client_to_handler_s, client_to_handler_r) = + async_channel::bounded(receive_maximum as usize); + + let (handler, packet_ids) = EventHandlerTask::new( + &options, + network_to_handler_r, + to_network_s.clone(), + client_to_handler_r, + ); + + let network = SmolNetwork::::new(options, network_to_handler_s, to_network_r); + + let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s); + + (network, handler, client) } -#[cfg(all(feature = "smol", feature = "tcp"))] -pub fn create_smol_tcp( +#[cfg(feature = "tokio")] +/// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`] +pub fn new_tokio( options: ConnectOptions, ) -> ( - MqttNetwork, + tokio_network::TokioNetwork, EventHandlerTask, AsyncClient, -) { - new(options) -} - -pub fn new(options: ConnectOptions) -> (MqttNetwork, EventHandlerTask, AsyncClient) +) where - R: AsyncMqttNetworkRead, - W: AsyncMqttNetworkWrite, -{ + S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, { let receive_maximum = options.receive_maximum(); let (to_network_s, to_network_r) = async_channel::bounded(100); @@ -191,103 +250,282 @@ where let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(receive_maximum as usize); - let last_network_action = Arc::new(Mutex::new(Instant::now())); - let (handler, packet_ids) = EventHandlerTask::new( &options, network_to_handler_r, to_network_s.clone(), - client_to_handler_r.clone(), - last_network_action.clone(), + client_to_handler_r, ); - let network = MqttNetwork::::new( - options, - network_to_handler_s, - to_network_r, - last_network_action, - ); + let network = + tokio_network::TokioNetwork::::new(options, network_to_handler_s, to_network_r); let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s); (network, handler, client) } - -#[cfg(all(test, any(feature = "smol-rustls", feature = "tokio-rustls")))] -mod lib_test{ +#[cfg(test)] +mod lib_test { + use std::time::Duration; + + use crate::{ + client::AsyncClient, + connect_options::ConnectOptions, + new_smol, new_tokio, + packets::{self, Packet}, + util::tls::tests::simple_rust_tls, + AsyncEventHandlerMut, HandlerStatus, NetworkStatus, + }; + use async_trait::async_trait; use bytes::Bytes; - use futures::join; - - use crate::{client::AsyncClient, packets::{self, Packet}}; - use crate::event_handler::AsyncEventHandler; - use crate::create_smol_rustls; - use crate::connections::transport::RustlsConfig; - use crate::connect_options::ConnectOptions; + use rustls::ServerName; pub struct PingPong { pub client: AsyncClient, } - - impl AsyncEventHandler for PingPong{ - fn handle<'a>(&'a mut self, event: &'a packets::Packet) -> impl std::future::Future + Send + 'a { - async move{ - match event{ - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()){ - if payload.to_lowercase().contains("ping"){ - self.client.publish(p.qos, p.retain, p.topic.clone(), Bytes::from_static(b"pong")).await; - println!("Received Ping, Send pong!"); - } + + #[async_trait] + impl AsyncEventHandlerMut for PingPong { + async fn handle(&mut self, event: &packets::Packet) -> () { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client + .publish( + p.qos, + p.retain, + p.topic.clone(), + Bytes::from_static(b"pong"), + ) + .await + .unwrap(); + println!("Received Ping, Send pong!"); } - }, - Packet::ConnAck(_) => { - println!("Connected!"); } - _ => (), } + _ => (), } } } - - // #[test] - fn test_smol(){ + #[test] + fn test_smol_tcp() { + smol::block_on(async { + let options = ConnectOptions::new("SmolTcpPingPong".to_string()); - let filter = tracing_subscriber::filter::EnvFilter::new("none,mqrstt=trace"); + let address = "broker.emqx.io"; + let port = 1883; - let subscriber = tracing_subscriber::FmtSubscriber::builder() - .with_env_filter(filter) - .with_max_level(tracing::Level::TRACE) - .with_line_number(true) - .finish(); + let (mut network, mut handler, client) = new_smol(options); - tracing::subscriber::set_global_default(subscriber) - .expect("setting default subscriber failed"); + let stream = smol::net::TcpStream::connect((address, port)) + .await + .unwrap(); - smol::block_on(async{ - let options = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "mqrstt".to_string()); - - let tls_config = RustlsConfig::Simple { - ca: crate::tests::resources::EMQX_CERT.to_vec(), - alpn: None, - client_auth: None, + network.connect(stream).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let mut pingpong = PingPong { + client: client.clone(), }; - - let (mut network, mut handler, client) = create_smol_rustls(options, tls_config); - + + let (n, h, _) = futures::join!( + async { + loop { + return match network.run().await { + Ok(NetworkStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + loop { + return match handler.handle_mut(&mut pingpong).await { + Ok(HandlerStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + smol::Timer::after(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + } + ); + assert!(n.is_ok()); + assert!(h.is_ok()); + }); + } + + #[test] + fn test_smol_tls() { + smol::block_on(async { + let options = ConnectOptions::new("SmolTlsPingPong".to_string()); + + let address = "broker.emqx.io"; + let port = 8883; + + let (mut network, mut handler, client) = new_smol(options); + + let arc_client_config = + simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); + + let domain = ServerName::try_from(address).unwrap(); + let connector = async_rustls::TlsConnector::from(arc_client_config); + + let stream = smol::net::TcpStream::connect((address, port)) + .await + .unwrap(); + let connection = connector.connect(domain, stream).await.unwrap(); + + network.connect(connection).await.unwrap(); + client.subscribe("mqrstt").await.unwrap(); - - - let mut pingpong = PingPong{ client }; - join!(network.run(), + let mut pingpong = PingPong { + client: client.clone(), + }; + + let (n, h, _) = futures::join!( async { - loop{ - handler.handle(&mut pingpong).await.unwrap(); + loop { + return match network.run().await { + Ok(NetworkStatus::Active) => continue, + otherwise => otherwise, + }; } + }, + async { + loop { + return match handler.handle_mut(&mut pingpong).await { + Ok(HandlerStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + smol::Timer::after(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); } - ).0.unwrap(); + ); + assert!(n.is_ok()); + assert!(h.is_ok()); }); } -} \ No newline at end of file + + #[tokio::test] + async fn test_tokio_tcp() { + let options = ConnectOptions::new("TokioTcpPingPong".to_string()); + + let (mut network, mut handler, client) = new_tokio(options); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) + .await + .unwrap(); + + network.connect(stream).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let mut pingpong = PingPong { + client: client.clone(), + }; + + let (n, h, _) = tokio::join!( + async { + loop { + return match network.run().await { + Ok(NetworkStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + loop { + return match handler.handle_mut(&mut pingpong).await { + Ok(HandlerStatus::Active) => continue, + otherwise => otherwise, + }; + } + }, + async { + tokio::time::sleep(Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + } + ); + assert!(n.is_ok()); + assert!(h.is_ok()); + + assert_eq!(NetworkStatus::OutgoingDisconnect, n.unwrap()); + assert_eq!(HandlerStatus::OutgoingDisconnect, h.unwrap()); + } + + #[tokio::test] + async fn test_tokio_tls() { + let options = ConnectOptions::new("TokioTlsPingPong".to_string()); + + let address = "broker.emqx.io"; + let port = 8883; + + let (mut network, mut handler, client) = new_tokio(options); + + let arc_client_config = + simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap(); + + let domain = ServerName::try_from(address).unwrap(); + let connector = tokio_rustls::TlsConnector::from(arc_client_config); + + let stream = tokio::net::TcpStream::connect((address, port)) + .await + .unwrap(); + let connection = connector.connect(domain, stream).await.unwrap(); + + network.connect(connection).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let mut pingpong = PingPong { + client: client.clone(), + }; + + let (n, h, _) = tokio::join!( + async { + loop { + return match network.run().await { + Ok(NetworkStatus::IncomingDisconnect) => { + Ok(NetworkStatus::IncomingDisconnect) + } + Ok(NetworkStatus::OutgoingDisconnect) => { + Ok(NetworkStatus::OutgoingDisconnect) + } + Ok(NetworkStatus::NoPingResp) => Ok(NetworkStatus::NoPingResp), + Ok(NetworkStatus::Active) => continue, + Err(a) => Err(a), + }; + } + }, + async { + loop { + return match handler.handle_mut(&mut pingpong).await { + Ok(HandlerStatus::IncomingDisconnect) => { + Ok(NetworkStatus::IncomingDisconnect) + } + Ok(HandlerStatus::OutgoingDisconnect) => { + Ok(NetworkStatus::OutgoingDisconnect) + } + Ok(HandlerStatus::Active) => continue, + Err(a) => Err(a), + }; + } + }, + async { + tokio::time::sleep(Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + } + ); + assert!(n.is_ok()); + assert!(h.is_ok()); + } +} diff --git a/src/network.rs b/src/network.rs deleted file mode 100644 index 81a3510..0000000 --- a/src/network.rs +++ /dev/null @@ -1,121 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; - -use async_channel::{Receiver, Sender}; -use async_mutex::Mutex; - -use futures_concurrency::future::Join; -use std::time::Instant; -use tracing::trace; - -use crate::connect_options::ConnectOptions; -use crate::connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite}; -use crate::error::ConnectionError; -use crate::packets::Packet; -use crate::util::timeout::{timeout, self}; - -pub type Incoming = Packet; -pub type Outgoing = Packet; - -pub struct MqttNetwork { - network: Option<(R, W)>, - - // write_buffer: BytesMut, - /// Options of the current mqtt connection - options: ConnectOptions, - - last_network_action: Arc>, - - network_to_handler_s: Sender, - // incoming_packet_receiver: Receiver, - // outgoing_packet_sender: Sender, - to_network_r: Receiver, -} - -impl MqttNetwork -where - R: AsyncMqttNetworkRead, - W: AsyncMqttNetworkWrite, -{ - pub fn new( - options: ConnectOptions, - network_to_handler_s: Sender, - to_network_r: Receiver, - last_network_action: Arc>, - ) -> Self { - Self { - network: None, - - options, - - last_network_action, - - network_to_handler_s, - to_network_r, - } - } - - pub fn reset(&mut self) { - self.network = None; - } - - pub async fn run(&mut self) -> Result<(), ConnectionError> { - if self.network.is_none() { - trace!("Creating network"); - - // let con = R::connect(&self.options).await; - let con = timeout::timeout(R::connect(&self.options), self.options.connection_timeout_s).await?; - - let (reader, writer, connack) = con?; - - trace!("Succesfully created network"); - self.network = Some((reader, writer)); - self.network_to_handler_s.send(connack).await?; - } - - let MqttNetwork { - network, - options: _, - last_network_action, - network_to_handler_s, - to_network_r, - } = self; - - if let Some((reader, writer)) = network { - let disconnect = AtomicBool::new(false); - - let incoming = async { - loop { - let local_disconnect = reader.read_direct(network_to_handler_s).await?; - *(last_network_action.lock().await) = std::time::Instant::now(); - if local_disconnect { - disconnect.store(true, Ordering::Release); - return Ok(()); - } else if disconnect.load(Ordering::Acquire) { - return Ok(()); - } - } - }; - - let outgoing = async { - loop { - let local_disconnect = writer.write(to_network_r).await?; - *(last_network_action.lock().await) = std::time::Instant::now(); - if local_disconnect { - disconnect.store(true, Ordering::Release); - return Ok(()); - } else if disconnect.load(Ordering::Acquire) { - return Ok(()); - } - } - }; - - let res: (Result<(), ConnectionError>, Result<(), ConnectionError>) = - (incoming, outgoing).join().await; - res.0?; - res.1 - } else { - Ok(()) - } - } -} diff --git a/src/packets/auth.rs b/src/packets/auth.rs index 86d56ba..04f47ce 100644 --- a/src/packets/auth.rs +++ b/src/packets/auth.rs @@ -67,7 +67,8 @@ impl MqttRead for AuthProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::MalformedPacket); } diff --git a/src/packets/connack.rs b/src/packets/connack.rs index 1724ce1..6427054 100644 --- a/src/packets/connack.rs +++ b/src/packets/connack.rs @@ -123,7 +123,8 @@ impl MqttRead for ConnAckProperties { let mut properties = Self::default(); if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "ConnAckProperties".to_string(), buf.len(), diff --git a/src/packets/connect.rs b/src/packets/connect.rs index 5e11ff4..297fb73 100644 --- a/src/packets/connect.rs +++ b/src/packets/connect.rs @@ -128,12 +128,14 @@ impl VariableHeaderRead for Connect { let username = if connect_flags.contains(ConnectFlags::USERNAME) { Some(String::read(&mut buf)?) - } else { + } + else { None }; let password = if connect_flags.contains(ConnectFlags::PASSWORD) { Some(String::read(&mut buf)?) - } else { + } + else { None }; @@ -351,7 +353,8 @@ impl MqttRead for ConnectProperties { let mut properties = Self::default(); if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "ConnectProperties".to_string(), buf.len(), @@ -573,7 +576,8 @@ impl MqttRead for LastWillProperties { let mut properties = Self::default(); if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "LastWillProperties".to_string(), buf.len(), diff --git a/src/packets/disconnect.rs b/src/packets/disconnect.rs index 254f6fc..908cbf0 100644 --- a/src/packets/disconnect.rs +++ b/src/packets/disconnect.rs @@ -34,7 +34,8 @@ impl VariableHeaderRead for Disconnect { if remaining_length == 0 { reason_code = DisconnectReasonCode::NormalDisconnection; properties = DisconnectProperties::default(); - } else { + } + else { reason_code = DisconnectReasonCode::read(&mut buf)?; properties = DisconnectProperties::read(&mut buf)?; } @@ -62,7 +63,8 @@ impl WireLength for Disconnect { || self.properties.wire_len() != 0 { self.properties.wire_len() + 1 - } else { + } + else { 0 } } @@ -83,7 +85,8 @@ impl MqttRead for DisconnectProperties { let mut properties = Self::default(); if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "DisconnectProperties".to_string(), buf.len(), diff --git a/src/packets/mod.rs b/src/packets/mod.rs index f12c40f..3e90c4e 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -2,14 +2,13 @@ pub mod error; pub mod mqtt_traits; pub mod reason_codes; -mod publish; mod auth; mod connack; mod connect; mod disconnect; -mod packets; mod puback; mod pubcomp; +mod publish; mod pubrec; mod pubrel; mod suback; @@ -17,28 +16,26 @@ mod subscribe; mod unsuback; mod unsubscribe; - pub use auth::*; pub use connack::*; pub use connect::*; pub use disconnect::*; pub use puback::*; pub use pubcomp::*; +pub use publish::*; pub use pubrec::*; pub use pubrel::*; pub use suback::*; pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; -pub use publish::*; - -pub use packets::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; use core::slice::Iter; +use std::fmt::Display; use self::error::{DeserializeError, ReadBytes, SerializeError}; -use self::mqtt_traits::{MqttRead, MqttWrite, WireLength}; +use self::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; /// Protocol version #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] @@ -126,11 +123,14 @@ impl TryFrom for QoS { fn try_from(c: ConnectFlags) -> Result { if c.contains(ConnectFlags::WILL_QOS1 | ConnectFlags::WILL_QOS2) { Err(DeserializeError::MalformedPacket) - } else if c.contains(ConnectFlags::WILL_QOS2) { + } + else if c.contains(ConnectFlags::WILL_QOS2) { Ok(QoS::ExactlyOnce) - } else if c.contains(ConnectFlags::WILL_QOS1) { + } + else if c.contains(ConnectFlags::WILL_QOS1) { Ok(QoS::AtLeastOnce) - } else { + } + else { Ok(QoS::AtMostOnce) } } @@ -228,7 +228,8 @@ impl MqttWrite for bool { if *self { buf.put_u8(1); Ok(()) - } else { + } + else { buf.put_u8(0); Ok(()) } @@ -284,7 +285,8 @@ pub fn read_fixed_header_rem_len( if (*byte & 0b1000_0000) == 0 { return Ok((integer, length)); } - } else { + } + else { return Err(ReadBytes::InsufficientBytes(1)); } } @@ -335,11 +337,14 @@ pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result<(), pub fn variable_integer_len(integer: usize) -> usize { if integer >= 2_097_152 { 4 - } else if integer >= 16_384 { + } + else if integer >= 16_384 { 3 - } else if integer >= 128 { + } + else if integer >= 128 { 2 - } else { + } + else { 1 } } @@ -520,3 +525,514 @@ impl PropertyType { } } } + +// ==================== Packets ==================== + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Packet { + Connect(Connect), + ConnAck(ConnAck), + Publish(Publish), + PubAck(PubAck), + PubRec(PubRec), + PubRel(PubRel), + PubComp(PubComp), + Subscribe(Subscribe), + SubAck(SubAck), + Unsubscribe(Unsubscribe), + UnsubAck(UnsubAck), + PingReq, + PingResp, + Disconnect(Disconnect), + Auth(Auth), +} + +impl Packet { + pub fn packet_type(&self) -> PacketType { + match self { + Packet::Connect(_) => PacketType::Connect, + Packet::ConnAck(_) => PacketType::ConnAck, + Packet::Publish(_) => PacketType::Publish, + Packet::PubAck(_) => PacketType::PubAck, + Packet::PubRec(_) => PacketType::PubRec, + Packet::PubRel(_) => PacketType::PubRel, + Packet::PubComp(_) => PacketType::PubComp, + Packet::Subscribe(_) => PacketType::Subscribe, + Packet::SubAck(_) => PacketType::SubAck, + Packet::Unsubscribe(_) => PacketType::Unsubscribe, + Packet::UnsubAck(_) => PacketType::UnsubAck, + Packet::PingReq => PacketType::PingReq, + Packet::PingResp => PacketType::PingResp, + Packet::Disconnect(_) => PacketType::Disconnect, + Packet::Auth(_) => PacketType::Auth, + } + } + + pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + match self { + Packet::Connect(p) => { + buf.put_u8(0b0001_0000); + write_variable_integer(buf, p.wire_len())?; + + p.write(buf)?; + } + Packet::ConnAck(_) => { + unreachable!() + } + Packet::Publish(p) => { + let mut first_byte = 0b0011_0000u8; + if p.dup { + first_byte |= 0b1000; + } + + first_byte |= p.qos.into_u8() << 1; + + if p.retain { + first_byte |= 0b0001; + } + buf.put_u8(first_byte); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::PubAck(p) => { + buf.put_u8(0b0100_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::PubRec(p) => { + buf.put_u8(0b0101_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::PubRel(p) => { + buf.put_u8(0b0110_0010); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::PubComp(p) => { + buf.put_u8(0b0111_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::Subscribe(p) => { + buf.put_u8(0b1000_0010); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::SubAck(_) => { + unreachable!() + } + Packet::Unsubscribe(p) => { + buf.put_u8(0b1010_0010); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::UnsubAck(_) => { + buf.put_u8(0b1011_0000); + unreachable!() + } + Packet::PingReq => { + buf.put_u8(0b1100_0000); + buf.put_u8(0); // Variable header length. + } + Packet::PingResp => { + buf.put_u8(0b1101_0000); + buf.put_u8(0); // Variable header length. + } + Packet::Disconnect(p) => { + buf.put_u8(0b1110_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + Packet::Auth(p) => { + buf.put_u8(0b1111_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; + } + } + Ok(()) + } + + pub fn read(header: FixedHeader, buf: Bytes) -> Result { + let packet = match header.packet_type { + PacketType::Connect => { + Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?) + } + PacketType::ConnAck => { + Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?) + } + PacketType::Publish => { + Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?) + } + PacketType::PubAck => { + Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?) + } + PacketType::PubRec => { + Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?) + } + PacketType::PubRel => { + Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?) + } + PacketType::PubComp => { + Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?) + } + PacketType::Subscribe => { + Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?) + } + PacketType::SubAck => { + Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?) + } + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read( + header.flags, + header.remaining_length, + buf, + )?), + PacketType::UnsubAck => { + Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?) + } + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::read( + header.flags, + header.remaining_length, + buf, + )?), + PacketType::Auth => { + Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?) + } + }; + Ok(packet) + } + + pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; + + if header.remaining_length + header_length > buffer.len() { + return Err(ReadBytes::InsufficientBytes( + header.remaining_length - buffer.len(), + )); + } + buffer.advance(header_length); + + let buf = buffer.split_to(header.remaining_length); + + Ok(Packet::read(header, buf.into())?) + } +} + +impl Display for Packet { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self{ + Packet::Connect(c) => write!(f, "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", c.protocol_version, c.clean_session, c.username, c.password, c.keep_alive, c.client_id), + Packet::ConnAck(c) => write!(f, "ConnAck(session:{:?}, reason code{:?})", c.connack_flags, c.reason_code), + Packet::Publish(p) => write!(f, "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", &p.topic, p.qos, p.dup, p.retain, p.packet_identifier), + Packet::PubAck(p) => write!(f, "PubAck(id:{:?}, reason code: {:?})", p.packet_identifier, p.reason_code), + Packet::PubRec(p) => write!(f, "PubRec(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), + Packet::PubRel(p) => write!(f, "PubRel(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), + Packet::PubComp(p) => write!(f, "PubComp(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), + Packet::Subscribe(_) => write!(f, "Subscribe()"), + Packet::SubAck(_) => write!(f, "SubAck()"), + Packet::Unsubscribe(_) => write!(f, "Unsubscribe()"), + Packet::UnsubAck(_) => write!(f, "UnsubAck()"), + Packet::PingReq => write!(f, "PingReq"), + Packet::PingResp => write!(f, "PingResp"), + Packet::Disconnect(d) => write!(f, "Disconnect(reason code: {:?})", d.reason_code), + Packet::Auth(_) => write!(f, "Auth()"), + } + } +} + +// 2.1.1 Fixed Header +// ``` +// 7 3 0 +// +--------------------------+--------------------------+ +// byte 1 | MQTT Control Packet Type | Flags for Packet type | +// +--------------------------+--------------------------+ +// | Remaining Length | +// +-----------------------------------------------------+ +// +// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 +// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub struct FixedHeader { + pub packet_type: PacketType, + pub flags: u8, + pub remaining_length: usize, +} + +impl FixedHeader { + pub fn read_fixed_header( + mut header: Iter, + ) -> Result<(Self, usize), ReadBytes> { + if header.len() < 2 { + return Err(ReadBytes::InsufficientBytes(2 - header.len())); + } + + let first_byte = header.next().unwrap(); + + let (packet_type, flags) = + PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; + + let (remaining_length, length) = read_fixed_header_rem_len(header)?; + + Ok(( + Self { + packet_type, + flags, + remaining_length, + }, + 1 + length, + )) + } +} + +/// 2.1.2 MQTT Control Packet type +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum PacketType { + Connect, + ConnAck, + Publish, + PubAck, + PubRec, + PubRel, + PubComp, + Subscribe, + SubAck, + Unsubscribe, + UnsubAck, + PingReq, + PingResp, + Disconnect, + Auth, +} +impl PacketType { + const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> { + match (value >> 4, value & 0x0f) { + (0b0001, 0) => Ok((PacketType::Connect, 0)), + (0b0010, 0) => Ok((PacketType::ConnAck, 0)), + (0b0011, flags) => Ok((PacketType::Publish, flags)), + (0b0100, 0) => Ok((PacketType::PubAck, 0)), + (0b0101, 0) => Ok((PacketType::PubRec, 0)), + (0b0110, 0b0010) => Ok((PacketType::PubRel, 0)), + (0b0111, 0) => Ok((PacketType::PubComp, 0)), + (0b1000, 0b0010) => Ok((PacketType::Subscribe, 0)), + (0b1001, 0) => Ok((PacketType::SubAck, 0)), + (0b1010, 0b0010) => Ok((PacketType::Unsubscribe, 0)), + (0b1011, 0) => Ok((PacketType::UnsubAck, 0)), + (0b1100, 0) => Ok((PacketType::PingReq, 0)), + (0b1101, 0) => Ok((PacketType::PingResp, 0)), + (0b1110, 0) => Ok((PacketType::Disconnect, 0)), + (0b1111, 0) => Ok((PacketType::Auth, 0)), + (_, _) => Err(DeserializeError::UnknownFixedHeader(value)), + } + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties}; + use crate::packets::disconnect::{Disconnect, DisconnectProperties}; + use crate::packets::QoS; + + use crate::packets::publish::{Publish, PublishProperties}; + use crate::packets::pubrel::{PubRel, PubRelProperties}; + use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode}; + use crate::packets::Packet; + + #[test] + fn test_connack_read() { + let connack = [ + 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, + 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, + ]; + let mut buf = BytesMut::new(); + buf.extend(connack); + + let res = Packet::read_from_buffer(&mut buf); + assert!(res.is_ok()); + let res = res.unwrap(); + + let expected = ConnAck { + connack_flags: ConnAckFlags::SESSION_PRESENT, + reason_code: ConnAckReasonCode::Success, + connack_properties: ConnAckProperties { + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: Some(true), + maximum_packet_size: Some(1048576), + assigned_client_id: None, + topic_alias_maximum: Some(65535), + reason_string: None, + user_properties: vec![], + wildcards_available: Some(true), + subscription_ids_available: Some(true), + shared_subscription_available: Some(true), + server_keep_alive: None, + response_info: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }, + }; + + assert_eq!(Packet::ConnAck(expected), res); + } + + #[test] + fn test_disconnect_read() { + let packet = [0xe0, 0x02, 0x8e, 0x00]; + let mut buf = BytesMut::new(); + buf.extend(packet); + + let res = Packet::read_from_buffer(&mut buf); + assert!(res.is_ok()); + let res = res.unwrap(); + + let expected = Disconnect { + reason_code: DisconnectReasonCode::SessionTakenOver, + properties: DisconnectProperties { + session_expiry_interval: None, + reason_string: None, + user_properties: vec![], + server_reference: None, + }, + }; + + assert_eq!(Packet::Disconnect(expected), res); + } + + #[test] + fn test_pingreq_read_write() { + let packet = [0xc0, 0x00]; + let mut buf = BytesMut::new(); + buf.extend(packet); + + let res = Packet::read_from_buffer(&mut buf); + assert!(res.is_ok()); + let res = res.unwrap(); + + assert_eq!(Packet::PingReq, res); + + buf.clear(); + Packet::PingReq.write(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), packet); + } + + #[test] + fn test_pingresp_read_write() { + let packet = [0xd0, 0x00]; + let mut buf = BytesMut::new(); + buf.extend(packet); + + let res = Packet::read_from_buffer(&mut buf); + assert!(res.is_ok()); + let res = res.unwrap(); + + assert_eq!(Packet::PingResp, res); + + buf.clear(); + Packet::PingResp.write(&mut buf).unwrap(); + assert_eq!(buf.to_vec(), packet); + } + + #[test] + fn test_publish_read() { + let packet = [ + 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, + 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, + 0x01, 0x09, 0x00, 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, + ]; + + let mut buf = BytesMut::new(); + buf.extend(packet); + + let res = Packet::read_from_buffer(&mut buf); + assert!(res.is_ok()); + let res = res.unwrap(); + + let expected = Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: true, + topic: "test/123/test/blabla".to_string(), + packet_identifier: Some(13779), + publish_properties: PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: None, + topic_alias: None, + response_topic: None, + correlation_data: Some(Bytes::from_static(b"1212")), + subscription_identifier: vec![1], + user_properties: vec![], + content_type: None, + }, + payload: Bytes::from_static(b""), + }; + + assert_eq!(Packet::Publish(expected), res); + } + + #[test] + fn test_pubrel_read_write() { + let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00]; + + let mut buffer = BytesMut::from_iter(bytes); + + let res = Packet::read_from_buffer(&mut buffer); + + assert!(res.is_ok()); + + let packet = res.unwrap(); + + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + assert_eq!(packet, Packet::PubRel(expected)); + + buffer.clear(); + + packet.write(&mut buffer).unwrap(); + + // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format. + assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec()) + } + + #[test] + fn test_pubrel_read_smallest_format() { + let bytes = [0x62, 0x02, 0x35, 0xd3]; + + let mut buffer = BytesMut::from_iter(bytes); + + let res = Packet::read_from_buffer(&mut buffer); + + assert!(res.is_ok()); + + let packet = res.unwrap(); + + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + assert_eq!(packet, Packet::PubRel(expected)); + + buffer.clear(); + + packet.write(&mut buffer).unwrap(); + + assert_eq!(buffer.to_vec(), bytes.to_vec()) + } +} diff --git a/src/packets/packets.rs b/src/packets/packets.rs deleted file mode 100644 index 8831eaa..0000000 --- a/src/packets/packets.rs +++ /dev/null @@ -1,531 +0,0 @@ -use core::slice::Iter; -use std::fmt::Display; - -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::error::{DeserializeError, ReadBytes, SerializeError}; -use super::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite, WireLength}; -use super::{read_fixed_header_rem_len, write_variable_integer}; - -use super::auth::Auth; -use super::connack::ConnAck; -use super::connect::Connect; -use super::disconnect::Disconnect; -use super::puback::PubAck; -use super::pubcomp::PubComp; -use super::publish::Publish; -use super::pubrec::PubRec; -use super::pubrel::PubRel; -use super::suback::SubAck; -use super::subscribe::Subscribe; -use super::unsuback::UnsubAck; -use super::unsubscribe::Unsubscribe; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Packet { - Connect(Connect), - ConnAck(ConnAck), - Publish(Publish), - PubAck(PubAck), - PubRec(PubRec), - PubRel(PubRel), - PubComp(PubComp), - Subscribe(Subscribe), - SubAck(SubAck), - Unsubscribe(Unsubscribe), - UnsubAck(UnsubAck), - PingReq, - PingResp, - Disconnect(Disconnect), - Auth(Auth), -} - -impl Packet { - pub fn packet_type(&self) -> PacketType { - match self { - Packet::Connect(_) => PacketType::Connect, - Packet::ConnAck(_) => PacketType::ConnAck, - Packet::Publish(_) => PacketType::Publish, - Packet::PubAck(_) => PacketType::PubAck, - Packet::PubRec(_) => PacketType::PubRec, - Packet::PubRel(_) => PacketType::PubRel, - Packet::PubComp(_) => PacketType::PubComp, - Packet::Subscribe(_) => PacketType::Subscribe, - Packet::SubAck(_) => PacketType::SubAck, - Packet::Unsubscribe(_) => PacketType::Unsubscribe, - Packet::UnsubAck(_) => PacketType::UnsubAck, - Packet::PingReq => PacketType::PingReq, - Packet::PingResp => PacketType::PingResp, - Packet::Disconnect(_) => PacketType::Disconnect, - Packet::Auth(_) => PacketType::Auth, - } - } - - pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - match self { - Packet::Connect(p) => { - buf.put_u8(0b0001_0000); - write_variable_integer(buf, p.wire_len())?; - - p.write(buf)?; - } - Packet::ConnAck(_) => { - unreachable!() - } - Packet::Publish(p) => { - let mut first_byte = 0b0011_0000u8; - if p.dup { - first_byte |= 0b1000; - } - - first_byte |= p.qos.into_u8() << 1; - - if p.retain { - first_byte |= 0b0001; - } - buf.put_u8(first_byte); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::PubAck(p) => { - buf.put_u8(0b0100_0000); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::PubRec(p) => { - buf.put_u8(0b0101_0000); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::PubRel(p) => { - buf.put_u8(0b0110_0010); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::PubComp(p) => { - buf.put_u8(0b0111_0000); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::Subscribe(p) => { - buf.put_u8(0b1000_0010); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::SubAck(_) => { - unreachable!() - } - Packet::Unsubscribe(p) => { - buf.put_u8(0b1010_0010); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::UnsubAck(_) => { - buf.put_u8(0b1011_0000); - unreachable!() - } - Packet::PingReq => { - buf.put_u8(0b1100_0000); - buf.put_u8(0); // Variable header length. - } - Packet::PingResp => { - buf.put_u8(0b1101_0000); - buf.put_u8(0); // Variable header length. - } - Packet::Disconnect(p) => { - buf.put_u8(0b1110_0000); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - Packet::Auth(p) => { - buf.put_u8(0b1111_0000); - write_variable_integer(buf, p.wire_len())?; - p.write(buf)?; - } - } - Ok(()) - } - - pub fn read(header: FixedHeader, buf: Bytes) -> Result { - let packet = match header.packet_type { - PacketType::Connect => { - Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?) - } - PacketType::ConnAck => { - Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Publish => { - Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubAck => { - Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubRec => { - Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubRel => { - Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PubComp => { - Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Subscribe => { - Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?) - } - PacketType::SubAck => { - Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read( - header.flags, - header.remaining_length, - buf, - )?), - PacketType::UnsubAck => { - Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?) - } - PacketType::PingReq => Packet::PingReq, - PacketType::PingResp => Packet::PingResp, - PacketType::Disconnect => Packet::Disconnect(Disconnect::read( - header.flags, - header.remaining_length, - buf, - )?), - PacketType::Auth => { - Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?) - } - }; - Ok(packet) - } - - pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { - let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; - - if header.remaining_length + header_length > buffer.len() { - return Err(ReadBytes::InsufficientBytes( - header.remaining_length - buffer.len(), - )); - } - buffer.advance(header_length); - - let buf = buffer.split_to(header.remaining_length); - - return Ok(Packet::read(header, buf.into())?); - } -} - -impl Display for Packet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self{ - Packet::Connect(c) => write!(f, "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", c.protocol_version, c.clean_session, c.username, c.password, c.keep_alive, c.client_id), - Packet::ConnAck(c) => write!(f, "ConnAck(session:{:?}, reason code{:?})", c.connack_flags, c.reason_code), - Packet::Publish(p) => write!(f, "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", &p.topic, p.qos, p.dup, p.retain, p.packet_identifier), - Packet::PubAck(p) => write!(f, "PubAck(id:{:?}, reason code: {:?})", p.packet_identifier, p.reason_code), - Packet::PubRec(p) => write!(f, "PubRec(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), - Packet::PubRel(p) => write!(f, "PubRel(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), - Packet::PubComp(p) => write!(f, "PubComp(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code), - Packet::Subscribe(_) => write!(f, "Subscribe()"), - Packet::SubAck(_) => write!(f, "SubAck()"), - Packet::Unsubscribe(_) => write!(f, "Unsubscribe()"), - Packet::UnsubAck(_) => write!(f, "UnsubAck()"), - Packet::PingReq => write!(f, "PingReq"), - Packet::PingResp => write!(f, "PingResp"), - Packet::Disconnect(d) => write!(f, "Disconnect(reason code: {:?})", d.reason_code), - Packet::Auth(_) => write!(f, "Auth()"), - } - } -} - -// 2.1.1 Fixed Header -// ``` -// 7 3 0 -// +--------------------------+--------------------------+ -// byte 1 | MQTT Control Packet Type | Flags for Packet type | -// +--------------------------+--------------------------+ -// | Remaining Length | -// +-----------------------------------------------------+ -// -// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 -// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub struct FixedHeader { - pub packet_type: PacketType, - pub flags: u8, - pub remaining_length: usize, -} - -impl FixedHeader { - pub fn read_fixed_header( - mut header: Iter, - ) -> Result<(Self, usize), ReadBytes> { - if header.len() < 2 { - return Err(ReadBytes::InsufficientBytes(2 - header.len())); - } - - let first_byte = header.next().unwrap(); - - let (packet_type, flags) = - PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; - - let (remaining_length, length) = read_fixed_header_rem_len(header)?; - - Ok(( - Self { - packet_type, - flags, - remaining_length, - }, - 1 + length, - )) - } -} - -/// 2.1.2 MQTT Control Packet type -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum PacketType { - Connect, - ConnAck, - Publish, - PubAck, - PubRec, - PubRel, - PubComp, - Subscribe, - SubAck, - Unsubscribe, - UnsubAck, - PingReq, - PingResp, - Disconnect, - Auth, -} -impl PacketType { - const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> { - match (value >> 4, value & 0x0f) { - (0b0001, 0) => Ok((PacketType::Connect, 0)), - (0b0010, 0) => Ok((PacketType::ConnAck, 0)), - (0b0011, flags) => Ok((PacketType::Publish, flags)), - (0b0100, 0) => Ok((PacketType::PubAck, 0)), - (0b0101, 0) => Ok((PacketType::PubRec, 0)), - (0b0110, 0b0010) => Ok((PacketType::PubRel, 0)), - (0b0111, 0) => Ok((PacketType::PubComp, 0)), - (0b1000, 0b0010) => Ok((PacketType::Subscribe, 0)), - (0b1001, 0) => Ok((PacketType::SubAck, 0)), - (0b1010, 0b0010) => Ok((PacketType::Unsubscribe, 0)), - (0b1011, 0) => Ok((PacketType::UnsubAck, 0)), - (0b1100, 0) => Ok((PacketType::PingReq, 0)), - (0b1101, 0) => Ok((PacketType::PingResp, 0)), - (0b1110, 0) => Ok((PacketType::Disconnect, 0)), - (0b1111, 0) => Ok((PacketType::Auth, 0)), - (_, _) => Err(DeserializeError::UnknownFixedHeader(value)), - } - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties}; - use crate::packets::disconnect::{Disconnect, DisconnectProperties}; - use crate::packets::QoS; - - use crate::packets::packets::Packet; - use crate::packets::publish::{Publish, PublishProperties}; - use crate::packets::pubrel::{PubRel, PubRelProperties}; - use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode}; - - #[test] - fn test_connack_read() { - let connack = [ - 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, - 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, - ]; - let mut buf = BytesMut::new(); - buf.extend(connack); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = ConnAck { - connack_flags: ConnAckFlags::SESSION_PRESENT, - reason_code: ConnAckReasonCode::Success, - connack_properties: ConnAckProperties { - session_expiry_interval: None, - receive_maximum: None, - maximum_qos: None, - retain_available: Some(true), - maximum_packet_size: Some(1048576), - assigned_client_id: None, - topic_alias_maximum: Some(65535), - reason_string: None, - user_properties: vec![], - wildcards_available: Some(true), - subscription_ids_available: Some(true), - shared_subscription_available: Some(true), - server_keep_alive: None, - response_info: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }, - }; - - assert_eq!(Packet::ConnAck(expected), res); - } - - #[test] - fn test_disconnect_read() { - let packet = [0xe0, 0x02, 0x8e, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Disconnect { - reason_code: DisconnectReasonCode::SessionTakenOver, - properties: DisconnectProperties { - session_expiry_interval: None, - reason_string: None, - user_properties: vec![], - server_reference: None, - }, - }; - - assert_eq!(Packet::Disconnect(expected), res); - } - - #[test] - fn test_pingreq_read_write() { - let packet = [0xc0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingReq, res); - - buf.clear(); - Packet::PingReq.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_pingresp_read_write() { - let packet = [0xd0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingResp, res); - - buf.clear(); - Packet::PingResp.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_publish_read() { - let packet = [ - 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, - 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, - 0x01, 0x09, 0x00, 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, - ]; - - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Publish { - dup: false, - qos: QoS::ExactlyOnce, - retain: true, - topic: "test/123/test/blabla".to_string(), - packet_identifier: Some(13779), - publish_properties: PublishProperties { - payload_format_indicator: Some(1), - message_expiry_interval: None, - topic_alias: None, - response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), - subscription_identifier: vec![1], - user_properties: vec![], - content_type: None, - }, - payload: Bytes::from_static(b""), - }; - - assert_eq!(Packet::Publish(expected), res); - } - - #[test] - fn test_pubrel_read_write() { - let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00]; - - let mut buffer = BytesMut::from_iter(&bytes); - - let res = Packet::read_from_buffer(&mut buffer); - - assert!(res.is_ok()); - - let packet = res.unwrap(); - - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); - - buffer.clear(); - - packet.write(&mut buffer).unwrap(); - - // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format. - assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec()) - } - - #[test] - fn test_pubrel_read_smallest_format() { - let bytes = [0x62, 0x02, 0x35, 0xd3]; - - let mut buffer = BytesMut::from_iter(&bytes); - - let res = Packet::read_from_buffer(&mut buffer); - - assert!(res.is_ok()); - - let packet = res.unwrap(); - - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); - - buffer.clear(); - - packet.write(&mut buffer).unwrap(); - - assert_eq!(buffer.to_vec(), bytes.to_vec()) - } -} diff --git a/src/packets/puback.rs b/src/packets/puback.rs index bf22d0d..493b3ca 100644 --- a/src/packets/puback.rs +++ b/src/packets/puback.rs @@ -56,13 +56,15 @@ impl VariableHeaderWrite for PubAck { if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() { - () + && self.properties.user_properties.is_empty() + { + // nothing here } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { self.reason_code.write(buf)?; - } + } else { self.reason_code.write(buf)?; self.properties.write(buf)?; @@ -75,11 +77,13 @@ impl WireLength for PubAck { fn wire_len(&self) -> usize { if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 2 - } + } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 3 } else { @@ -184,7 +188,6 @@ mod tests { }; use bytes::{BufMut, Bytes, BytesMut}; - #[test] fn test_wire_len() { let mut puback = PubAck { @@ -196,10 +199,10 @@ mod tests { let mut buf = BytesMut::new(); puback.write(&mut buf).unwrap(); - + assert_eq!(2, puback.wire_len()); assert_eq!(2, buf.len()); - + puback.reason_code = PubAckReasonCode::NotAuthorized; buf.clear(); puback.write(&mut buf).unwrap(); @@ -208,7 +211,6 @@ mod tests { assert_eq!(3, buf.len()); } - #[test] fn test_read_simple_puback() { let stream = &[ diff --git a/src/packets/pubcomp.rs b/src/packets/pubcomp.rs index ec60ede..b1fdfdd 100644 --- a/src/packets/pubcomp.rs +++ b/src/packets/pubcomp.rs @@ -66,13 +66,15 @@ impl VariableHeaderWrite for PubComp { if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() { - () + && self.properties.user_properties.is_empty() + { + // nothing here } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { self.reason_code.write(buf)?; - } + } else { self.reason_code.write(buf)?; self.properties.write(buf)?; @@ -85,11 +87,13 @@ impl WireLength for PubComp { fn wire_len(&self) -> usize { if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 2 - } + } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 3 } else { @@ -98,7 +102,7 @@ impl WireLength for PubComp { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] pub struct PubCompProperties { pub reason_string: Option, pub user_properties: Vec<(String, String)>, @@ -110,15 +114,6 @@ impl PubCompProperties { } } -impl Default for PubCompProperties { - fn default() -> Self { - Self { - reason_string: Default::default(), - user_properties: Default::default(), - } - } -} - impl MqttRead for PubCompProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = read_variable_integer(buf)?; @@ -203,7 +198,6 @@ mod tests { }; use bytes::{BufMut, Bytes, BytesMut}; - #[test] fn test_wire_len() { let mut pubcomp = PubComp { diff --git a/src/packets/publish.rs b/src/packets/publish.rs index cbe7227..5262075 100644 --- a/src/packets/publish.rs +++ b/src/packets/publish.rs @@ -101,7 +101,8 @@ impl WireLength for Publish { let len = self.topic.wire_len() + if self.packet_identifier.is_some() { 2 - } else { + } + else { 0 } + self.publish_properties.wire_len() @@ -152,7 +153,8 @@ impl MqttRead for PublishProperties { if len == 0 { return Ok(Self::default()); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "PublishProperties".to_string(), buf.len(), diff --git a/src/packets/pubrec.rs b/src/packets/pubrec.rs index 448ee1a..737463c 100644 --- a/src/packets/pubrec.rs +++ b/src/packets/pubrec.rs @@ -65,13 +65,15 @@ impl VariableHeaderWrite for PubRec { if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() { - () + && self.properties.user_properties.is_empty() + { + // nothing here } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { self.reason_code.write(buf)?; - } + } else { self.reason_code.write(buf)?; self.properties.write(buf)?; @@ -84,11 +86,13 @@ impl WireLength for PubRec { fn wire_len(&self) -> usize { if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 2 - } + } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 3 } else { diff --git a/src/packets/pubrel.rs b/src/packets/pubrel.rs index 985ebbf..3c08f3b 100644 --- a/src/packets/pubrel.rs +++ b/src/packets/pubrel.rs @@ -38,13 +38,15 @@ impl VariableHeaderRead for PubRel { reason_code: PubRelReasonCode::Success, properties: PubRelProperties::default(), }) - } else if remaining_length == 3 { + } + else if remaining_length == 3 { Ok(Self { packet_identifier: u16::read(&mut buf)?, reason_code: PubRelReasonCode::read(&mut buf)?, properties: PubRelProperties::default(), }) - } else { + } + else { Ok(Self { packet_identifier: u16::read(&mut buf)?, reason_code: PubRelReasonCode::read(&mut buf)?, @@ -60,13 +62,15 @@ impl VariableHeaderWrite for PubRel { if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty() { - () + && self.properties.user_properties.is_empty() + { + // Nothing here } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { self.reason_code.write(buf)?; - } + } else { self.reason_code.write(buf)?; self.properties.write(buf)?; @@ -79,11 +83,13 @@ impl WireLength for PubRel { fn wire_len(&self) -> usize { if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 2 - } + } else if self.properties.reason_string.is_none() - && self.properties.user_properties.is_empty(){ + && self.properties.user_properties.is_empty() + { 3 } else { @@ -188,7 +194,6 @@ mod tests { }; use bytes::{BufMut, Bytes, BytesMut}; - #[test] fn test_wire_len() { let mut pubrel = PubRel { @@ -227,18 +232,17 @@ mod tests { assert_eq!(2, buf.len()); let pubrel = PubRel::read(0, 2, buf.into()).unwrap(); - + assert_eq!(expected_pubrel, pubrel); - + let mut buf = BytesMut::new(); expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; expected_pubrel.write(&mut buf).unwrap(); - + assert_eq!(3, buf.len()); - + let pubrel = PubRel::read(0, 3, buf.into()).unwrap(); assert_eq!(expected_pubrel, pubrel); - } #[test] diff --git a/src/packets/suback.rs b/src/packets/suback.rs index f86fa40..31dc5bb 100644 --- a/src/packets/suback.rs +++ b/src/packets/suback.rs @@ -77,7 +77,8 @@ impl MqttRead for SubAckProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "SubAckProperties".to_string(), buf.len(), @@ -94,7 +95,8 @@ impl MqttRead for SubAckProperties { let (subscription_id, _) = read_variable_integer(&mut properties_data)?; properties.subscription_id = Some(subscription_id); - } else { + } + else { return Err(DeserializeError::DuplicateProperty( PropertyType::SubscriptionIdentifier, )); diff --git a/src/packets/subscribe.rs b/src/packets/subscribe.rs index 2a9c12a..94a1b5f 100644 --- a/src/packets/subscribe.rs +++ b/src/packets/subscribe.rs @@ -96,7 +96,8 @@ impl MqttRead for SubscribeProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "SubscribeProperties".to_string(), buf.len(), @@ -113,7 +114,8 @@ impl MqttRead for SubscribeProperties { let (subscription_id, _) = read_variable_integer(&mut properties_data)?; properties.subscription_id = Some(subscription_id); - } else { + } + else { return Err(DeserializeError::DuplicateProperty( PropertyType::SubscriptionIdentifier, )); @@ -267,7 +269,7 @@ mod tests { use crate::packets::{ mqtt_traits::{MqttRead, VariableHeaderRead, VariableHeaderWrite}, - packets::Packet, + Packet, }; use super::WireLength; diff --git a/src/packets/unsuback.rs b/src/packets/unsuback.rs index da1aa1b..94536bf 100644 --- a/src/packets/unsuback.rs +++ b/src/packets/unsuback.rs @@ -71,7 +71,8 @@ impl MqttRead for UnsubAckProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "UnsubAckProperties".to_string(), buf.len(), @@ -86,7 +87,8 @@ impl MqttRead for UnsubAckProperties { PropertyType::ReasonString => { if properties.reason_string.is_none() { properties.reason_string = Some(String::read(&mut properties_data)?); - } else { + } + else { return Err(DeserializeError::DuplicateProperty( PropertyType::SubscriptionIdentifier, )); diff --git a/src/packets/unsubscribe.rs b/src/packets/unsubscribe.rs index db81eda..225410c 100644 --- a/src/packets/unsubscribe.rs +++ b/src/packets/unsubscribe.rs @@ -76,7 +76,8 @@ impl MqttRead for UnsubscribeProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } + else if buf.len() < len { return Err(DeserializeError::InsufficientData( "UnsubscribeProperties".to_string(), buf.len(), diff --git a/src/smol_network.rs b/src/smol_network.rs new file mode 100644 index 0000000..eab20e9 --- /dev/null +++ b/src/smol_network.rs @@ -0,0 +1,166 @@ +use async_channel::{Receiver, Sender}; + +use futures::{select, FutureExt}; +use smol::io::{AsyncReadExt, AsyncWriteExt}; + +use std::time::{Duration, Instant}; + +use crate::connect_options::ConnectOptions; +use crate::connections::smol_stream::SmolStream; +use crate::error::ConnectionError; +use crate::packets::error::ReadBytes; +use crate::packets::{Packet, PacketType}; +use crate::NetworkStatus; + +pub struct SmolNetwork { + network: Option>, + + /// Options of the current mqtt connection + options: ConnectOptions, + + last_network_action: Instant, + await_pingresp: Option, + perform_keep_alive: bool, + + network_to_handler_s: Sender, + + to_network_r: Receiver, +} + +impl SmolNetwork +where + S: AsyncReadExt + AsyncWriteExt + Sized + Unpin, +{ + pub fn new( + options: ConnectOptions, + network_to_handler_s: Sender, + to_network_r: Receiver, + ) -> Self { + Self { + network: None, + + options, + + last_network_action: Instant::now(), + await_pingresp: None, + perform_keep_alive: true, + + network_to_handler_s, + to_network_r, + } + } + + pub fn reset(&mut self) { + self.network = None; + } + + pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { + let (network, connack) = SmolStream::connect(&self.options, stream).await?; + + self.network = Some(network); + self.network_to_handler_s.send(connack).await?; + self.last_network_action = Instant::now(); + if self.options.keep_alive_interval_s == 0 { + self.perform_keep_alive = false; + } + Ok(()) + } + + pub async fn run(&mut self) -> Result { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + match self.select().await { + Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), + otherwise => { + self.reset(); + otherwise + } + } + } + + async fn select(&mut self) -> Result { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + let SmolNetwork { + network, + options: _, + last_network_action, + await_pingresp, + perform_keep_alive, + network_to_handler_s, + to_network_r, + } = self; + + let sleep; + if !(*perform_keep_alive) { + sleep = Duration::new(3600, 0); + } + else if let Some(instant) = await_pingresp { + sleep = + *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } + else { + sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) + - Instant::now(); + } + + if let Some(stream) = network { + loop { + select! { + _ = stream.read_bytes().fuse() => { + match stream.parse_messages(network_to_handler_s).await { + Err(ReadBytes::Err(err)) => return Err(err), + Err(ReadBytes::InsufficientBytes(_)) => continue, + Ok(Some(PacketType::PingResp)) => { + *await_pingresp = None; + return Ok(NetworkStatus::Active) + }, + Ok(Some(PacketType::Disconnect)) => { + return Ok(NetworkStatus::IncomingDisconnect) + }, + Ok(_) => { + return Ok(NetworkStatus::Active) + } + }; + }, + outgoing = to_network_r.recv().fuse() => { + let packet = outgoing?; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + if packet.packet_type() == PacketType::Disconnect{ + return Ok(NetworkStatus::OutgoingDisconnect); + } + return Ok(NetworkStatus::Active); + }, + _ = smol::Timer::after(sleep).fuse() => { + if await_pingresp.is_none() && *perform_keep_alive{ + let packet = Packet::PingReq; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + *await_pingresp = Some(Instant::now()); + return Ok(NetworkStatus::Active); + } + else{ + return Ok(NetworkStatus::NoPingResp); + } + }, + } + } + } + else { + Err(ConnectionError::NoNetwork) + } + } +} + +#[test] +fn test() { + let a = Instant::now() - Duration::from_secs(100); + + let sleep = a + Duration::from_secs(60) - Instant::now(); + dbg!(sleep); +} diff --git a/src/state.rs b/src/state.rs index de0e9fc..0aaeeac 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,7 +1,6 @@ use std::collections::{BTreeMap, BTreeSet}; use async_channel::Receiver; -use async_mutex::Mutex; use crate::{ available_packet_ids::AvailablePacketIds, diff --git a/src/tests/handler_tests.rs b/src/tests/handler_tests.rs deleted file mode 100644 index 8b13789..0000000 --- a/src/tests/handler_tests.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 6e5737d..0525f44 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,4 +1,2 @@ mod connection_tests; -mod handler_tests; pub mod resources; -pub mod stages; diff --git a/src/tests/resources/test_packets.rs b/src/tests/resources/test_packets.rs index dabf117..91e8ef3 100644 --- a/src/tests/resources/test_packets.rs +++ b/src/tests/resources/test_packets.rs @@ -1,14 +1,9 @@ -#[allow(dead_code)] use bytes::Bytes; use crate::packets::{ - Disconnect, DisconnectProperties, - Packet, - PubAck, PubAckProperties, - Publish, PublishProperties, - Subscribe, Subscription, - QoS, - reason_codes::{PubAckReasonCode, DisconnectReasonCode}, + reason_codes::{DisconnectReasonCode, PubAckReasonCode}, + Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties, + QoS, Subscribe, Subscription, }; pub fn publish_packets() -> Vec { diff --git a/src/tests/stages.rs b/src/tests/stages.rs deleted file mode 100644 index 44f729f..0000000 --- a/src/tests/stages.rs +++ /dev/null @@ -1,69 +0,0 @@ -use tracing::{trace, warn}; - -use crate::{event_handler::AsyncEventHandler, packets::Packet}; - -pub struct Nop { } -impl AsyncEventHandler for Nop { - fn handle<'a>( - &'a mut self, - event: &'a Packet, - ) -> impl core::future::Future + Send + 'a { - async move { - // warn!("{:?}", event) - } - } -} - - -pub mod qos_2 { - use crate::{ - client::AsyncClient, - event_handler::AsyncEventHandler, packets::{Packet, PacketType}, - // packets::{Packet, PacketType}, - }; - - pub struct TestPubQoS2 { - stage: StagePubQoS2, - client: AsyncClient, - } - pub enum StagePubQoS2 { - ConnAck, - PubRec, - PubComp, - Done, - } - impl TestPubQoS2 { - #[allow(dead_code)] - pub fn new(client: AsyncClient) -> Self { - TestPubQoS2 { - stage: StagePubQoS2::ConnAck, - client, - } - } - } - impl AsyncEventHandler for TestPubQoS2 { - fn handle<'a>( - &'a mut self, - event: &'a Packet, - ) -> impl core::future::Future + Send + 'a { - async move { - match self.stage { - StagePubQoS2::ConnAck => { - assert_eq!(event.packet_type(), PacketType::ConnAck); - self.stage = StagePubQoS2::PubRec; - } - StagePubQoS2::PubRec => { - assert_eq!(event.packet_type(), PacketType::PubRec); - self.stage = StagePubQoS2::PubComp; - } - StagePubQoS2::PubComp => { - assert_eq!(event.packet_type(), PacketType::PubComp); - self.stage = StagePubQoS2::Done; - self.client.disconnect().await.unwrap(); - } - StagePubQoS2::Done => (), - } - } - } - } -} diff --git a/src/tokio_network.rs b/src/tokio_network.rs new file mode 100644 index 0000000..79585d8 --- /dev/null +++ b/src/tokio_network.rs @@ -0,0 +1,156 @@ +use async_channel::{Receiver, Sender}; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use std::time::{Duration, Instant}; + +use crate::connect_options::ConnectOptions; +use crate::connections::tokio_stream::TokioStream; +use crate::error::ConnectionError; +use crate::packets::error::ReadBytes; +use crate::packets::{Packet, PacketType}; +use crate::NetworkStatus; + +pub struct TokioNetwork { + network: Option>, + + /// Options of the current mqtt connection + options: ConnectOptions, + + last_network_action: Instant, + await_pingresp: Option, + perform_keep_alive: bool, + + network_to_handler_s: Sender, + + to_network_r: Receiver, +} + +impl TokioNetwork +where + S: AsyncReadExt + AsyncWriteExt + Sized + Unpin, +{ + pub fn new( + options: ConnectOptions, + network_to_handler_s: Sender, + to_network_r: Receiver, + ) -> Self { + Self { + network: None, + + options, + + last_network_action: Instant::now(), + await_pingresp: None, + perform_keep_alive: true, + + network_to_handler_s, + to_network_r, + } + } + + pub fn reset(&mut self) { + self.network = None; + } + + pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> { + let (network, connack) = TokioStream::connect(&self.options, stream).await?; + + self.network = Some(network); + self.network_to_handler_s.send(connack).await?; + self.last_network_action = Instant::now(); + if self.options.keep_alive_interval_s == 0 { + self.perform_keep_alive = false; + } + Ok(()) + } + + pub async fn run(&mut self) -> Result { + if self.network.is_none() { + return Err(ConnectionError::NoNetwork); + } + + match self.select().await { + Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active), + otherwise => { + self.reset(); + otherwise + } + } + } + + async fn select(&mut self) -> Result { + let TokioNetwork { + network, + options: _, + last_network_action, + await_pingresp, + perform_keep_alive, + network_to_handler_s, + to_network_r, + } = self; + + let sleep; + if let Some(instant) = await_pingresp { + sleep = + *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now(); + } + else { + sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s) + - Instant::now(); + } + + if let Some(stream) = network { + loop { + tokio::select! { + _ = stream.read_bytes() => { + match stream.parse_messages(network_to_handler_s).await { + Err(ReadBytes::Err(err)) => return Err(err), + Err(ReadBytes::InsufficientBytes(_)) => continue, + Ok(Some(PacketType::PingResp)) => { + *await_pingresp = None; + return Ok(NetworkStatus::Active) + }, + Ok(Some(PacketType::Disconnect)) => { + return Ok(NetworkStatus::IncomingDisconnect) + }, + Ok(_) => { + return Ok(NetworkStatus::Active) + } + }; + }, + outgoing = to_network_r.recv() => { + let packet = outgoing?; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + if packet.packet_type() == PacketType::Disconnect{ + return Ok(NetworkStatus::OutgoingDisconnect); + } + return Ok(NetworkStatus::Active); + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + let packet = Packet::PingReq; + stream.write(&packet).await?; + *last_network_action = Instant::now(); + *await_pingresp = Some(Instant::now()); + return Ok(NetworkStatus::Active); + }, + _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + return Ok(NetworkStatus::NoPingResp); + } + } + } + } + else { + Err(ConnectionError::NoNetwork) + } + } +} + +#[test] +fn test() { + let a = Instant::now() - Duration::from_secs(100); + + let sleep = a + Duration::from_secs(60) - Instant::now(); + dbg!(sleep); +} diff --git a/src/util/mod.rs b/src/util/mod.rs index e5399c5..52ab346 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,2 +1,3 @@ pub mod constants; pub mod timeout; +pub mod tls; diff --git a/src/util/timeout.rs b/src/util/timeout.rs index 02bd5b4..69f175a 100644 --- a/src/util/timeout.rs +++ b/src/util/timeout.rs @@ -1,29 +1,29 @@ -use std::fmt::Display; +// use std::fmt::Display; -use futures_concurrency::future::Race; +// use futures_concurrency::future::Race; -#[derive(Debug, Clone, Copy)] -pub struct Timeout(()); +// #[derive(Debug, Clone, Copy)] +// pub struct Timeout(()); -impl Display for Timeout { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Timeout") - } -} +// impl Display for Timeout { +// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +// write!(f, "Timeout") +// } +// } -impl std::error::Error for Timeout {} +// impl std::error::Error for Timeout {} -pub async fn timeout( - fut: impl core::future::Future, - delay_seconds: u64, -) -> Result { - (async { Ok(fut.await) }, async { - #[cfg(feature = "smol")] - smol::Timer::after(std::time::Duration::from_secs(delay_seconds)).await; - #[cfg(feature = "tokio")] - tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await; - Err(Timeout(())) - }) - .race() - .await -} +// pub async fn timeout( +// fut: impl core::future::Future, +// delay_seconds: u64, +// ) -> Result { +// (async { Ok(fut.await) }, async { +// #[cfg(feature = "smol")] +// smol::Timer::after(std::time::Duration::from_secs(delay_seconds)).await; +// #[cfg(feature = "tokio")] +// tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await; +// Err(Timeout(())) +// }) +// .race() +// .await +// } diff --git a/src/util/tls.rs b/src/util/tls.rs new file mode 100644 index 0000000..39c4906 --- /dev/null +++ b/src/util/tls.rs @@ -0,0 +1,75 @@ +#[cfg(test)] +pub mod tests { + use std::{ + io::{BufReader, Cursor}, + sync::Arc, + }; + + use rustls::{Certificate, ClientConfig, Error, OwnedTrustAnchor, RootCertStore}; + + #[derive(Debug, Clone)] + pub enum PrivateKey { + RSA(Vec), + ECC(Vec), + } + + pub fn simple_rust_tls( + ca: Vec, + alpn: Option>>, + client_auth: Option<(Vec, PrivateKey)>, + ) -> Result, Error> { + let mut root_cert_store = RootCertStore::empty(); + + let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); + + let trust_anchors = ca_certs.iter().map_while(|cert| { + if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { + Some(OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + )) + } + else { + None + } + }); + root_cert_store.add_server_trust_anchors(trust_anchors); + + assert!(!root_cert_store.is_empty()); + + let config = ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_cert_store); + + let mut config = match client_auth { + Some((client_cert_info, client_private_info)) => { + let read_private_keys = match client_private_info { + PrivateKey::RSA(rsa) => { + rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))) + } + PrivateKey::ECC(ecc) => { + rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))) + } + } + .unwrap(); + + let key = read_private_keys.into_iter().next().unwrap(); + + let client_certs = + rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))) + .unwrap(); + let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); + + config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))? + } + None => config.with_no_client_auth(), + }; + + if let Some(alpn) = alpn { + config.alpn_protocols.extend(alpn) + } + + Ok(Arc::new(config)) + } +}