From 7cd0424e8a3302d5ef90e7d4c9eb809056bc0ff8 Mon Sep 17 00:00:00 2001
From: Gunnar <13799935+GunnarMorrigan@users.noreply.github.com>
Date: Tue, 10 Jan 2023 11:36:02 +0100
Subject: [PATCH] Changed connections to allow user to provide the stream (#4)
* Added keep alive to tokio network
* Removed nested packets mod
* Remove nightly toolchain and async_fn_in_trait feature
* Removed more code dependent on async_fn_in_trait
* Adjusted smol stream to return pingresp and disconnect indicators too.
* Removed atomic_disconnect and added return info
Removed atomic disconnect from event handler and added enum to return status of the event handler.
* Removed commented section
* Added return info and keep alive
* Added async_trait for handler
* Removed stages, they don't work atm
* Adjusted license to MPL
* Removed warn from tokio_network
* Fixed tokio example test in lib
* Always perform reset in the needed cases
* Adjust lib tests
* fmt and clippy
* Adjusted trait bounds
* Adjusted doc example
* Fix lib test cases
* rustfmt
* Remove vscode dir
* Add license file
* Added codecov badge and adjusted version
* Fix license cargo deny
---
.vscode/settings.json | 9 -
Cargo.lock | 85 +--
Cargo.toml | 39 +-
LICENSE | 373 +++++++++++
README.md | 27 +-
deny.toml | 17 +-
rust-toolchain.toml | 2 -
rustfmt.toml | 3 +
src/available_packet_ids.rs | 32 +-
src/client.rs | 11 +-
src/connect_options.rs | 93 +--
src/connections/async_native_tls.rs | 236 -------
src/connections/async_rustls.rs | 297 ---------
src/connections/mod.rs | 81 +--
src/connections/smol_stream.rs | 176 +++++
src/connections/smol_tcp.rs | 220 -------
src/connections/tokio_rustls.rs | 290 ---------
src/connections/tokio_stream.rs | 174 +++++
src/connections/tokio_tcp.rs | 223 -------
src/connections/transport.rs | 31 -
src/connections/util.rs | 74 ---
src/error.rs | 43 +-
src/event_handler.rs | 975 ++++++++++++++--------------
src/lib.rs | 580 ++++++++++++-----
src/network.rs | 121 ----
src/packets/auth.rs | 3 +-
src/packets/connack.rs | 3 +-
src/packets/connect.rs | 12 +-
src/packets/disconnect.rs | 9 +-
src/packets/mod.rs | 546 +++++++++++++++-
src/packets/packets.rs | 531 ---------------
src/packets/puback.rs | 24 +-
src/packets/pubcomp.rs | 30 +-
src/packets/publish.rs | 6 +-
src/packets/pubrec.rs | 18 +-
src/packets/pubrel.rs | 34 +-
src/packets/suback.rs | 6 +-
src/packets/subscribe.rs | 8 +-
src/packets/unsuback.rs | 6 +-
src/packets/unsubscribe.rs | 3 +-
src/smol_network.rs | 166 +++++
src/state.rs | 1 -
src/tests/handler_tests.rs | 1 -
src/tests/mod.rs | 2 -
src/tests/resources/test_packets.rs | 11 +-
src/tests/stages.rs | 69 --
src/tokio_network.rs | 156 +++++
src/util/mod.rs | 1 +
src/util/timeout.rs | 48 +-
src/util/tls.rs | 75 +++
50 files changed, 2803 insertions(+), 3178 deletions(-)
delete mode 100644 .vscode/settings.json
create mode 100644 LICENSE
delete mode 100644 rust-toolchain.toml
create mode 100644 rustfmt.toml
delete mode 100644 src/connections/async_native_tls.rs
delete mode 100644 src/connections/async_rustls.rs
create mode 100644 src/connections/smol_stream.rs
delete mode 100644 src/connections/smol_tcp.rs
delete mode 100644 src/connections/tokio_rustls.rs
create mode 100644 src/connections/tokio_stream.rs
delete mode 100644 src/connections/tokio_tcp.rs
delete mode 100644 src/connections/transport.rs
delete mode 100644 src/connections/util.rs
delete mode 100644 src/network.rs
delete mode 100644 src/packets/packets.rs
create mode 100644 src/smol_network.rs
delete mode 100644 src/tests/handler_tests.rs
delete mode 100644 src/tests/stages.rs
create mode 100644 src/tokio_network.rs
create mode 100644 src/util/tls.rs
diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index cf4cb03..0000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,9 +0,0 @@
-{
- "cSpell.words": [
- "acked",
- "QUIC",
- "runtimes",
- "smol",
- "unacked"
- ]
-}
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index 0689cb2..a76ad69 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -125,6 +125,17 @@ version = "4.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a40729d2133846d9ed0ea60a8b9541bccddab49cd30f0715a1da672fe9a2524"
+[[package]]
+name = "async-trait"
+version = "0.1.61"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "705339e0e4a9690e2908d2b3d049d85682cf19fbd5782494498fbf7003a6a282"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
[[package]]
name = "atomic-waker"
version = "1.0.0"
@@ -149,18 +160,6 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
-[[package]]
-name = "bitvec"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
-dependencies = [
- "funty",
- "radium",
- "tap",
- "wyz",
-]
-
[[package]]
name = "blocking"
version = "1.3.0"
@@ -248,12 +247,6 @@ dependencies = [
"instant",
]
-[[package]]
-name = "funty"
-version = "2.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
-
[[package]]
name = "futures"
version = "0.3.25"
@@ -278,17 +271,6 @@ dependencies = [
"futures-sink",
]
-[[package]]
-name = "futures-concurrency"
-version = "7.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a740c32e1bde284ce2f51df98abd4fa38e9e539670443c111211777e3ab09927"
-dependencies = [
- "bitvec",
- "futures-core",
- "pin-project",
-]
-
[[package]]
name = "futures-core"
version = "0.3.25"
@@ -434,15 +416,15 @@ dependencies = [
[[package]]
name = "mqrstt"
-version = "0.1.0"
+version = "0.1.1"
dependencies = [
"async-channel",
"async-mutex",
"async-rustls",
+ "async-trait",
"bitflags",
"bytes",
"futures",
- "futures-concurrency",
"pretty_assertions",
"rustls",
"rustls-pemfile",
@@ -502,26 +484,6 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "427c3892f9e783d91cc128285287e70a59e206ca452770ece88a76f7a3eddd72"
-[[package]]
-name = "pin-project"
-version = "1.0.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
-dependencies = [
- "pin-project-internal",
-]
-
-[[package]]
-name = "pin-project-internal"
-version = "1.0.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
-
[[package]]
name = "pin-project-lite"
version = "0.2.9"
@@ -578,12 +540,6 @@ dependencies = [
"proc-macro2",
]
-[[package]]
-name = "radium"
-version = "0.7.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
-
[[package]]
name = "regex"
version = "1.7.0"
@@ -741,12 +697,6 @@ dependencies = [
"unicode-ident",
]
-[[package]]
-name = "tap"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
-
[[package]]
name = "thiserror"
version = "1.0.37"
@@ -1070,15 +1020,6 @@ version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5"
-[[package]]
-name = "wyz"
-version = "0.5.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
-dependencies = [
- "tap",
-]
-
[[package]]
name = "yansi"
version = "0.5.1"
diff --git a/Cargo.toml b/Cargo.toml
index 8b1ba1c..7f409aa 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,23 +1,23 @@
[package]
name = "mqrstt"
-version = "0.1.0"
+version = "0.1.1"
+homepage = "https://github.com/GunnarMorrigan/mqrstt"
+repository = "https://github.com/GunnarMorrigan/mqrstt"
+categories = ["network-programming"]
+readme = "README.md"
edition = "2021"
-license = "Apache-2.0"
+license = "MPL-2.0"
+keywords = [ "MQTT", "IoT", "MQTTv5", "messaging", "client" ]
+description = "Pure rust MQTTv5 client implementation for Smol, Tokio and soon sync too."
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
-default = ["smol", "tcp"]
+default = ["smol", "tokio"]
tokio = ["dep:tokio"]
smol = ["dep:smol"]
-tcp = []
-tokio-rustls = ["dep:tokio-rustls", "dep:rustls", "dep:rustls-pemfile", "dep:webpki"]
-smol-rustls = ["dep:async-rustls", "dep:rustls", "dep:rustls-pemfile", "dep:webpki"]
-
-# If in the future we only provide an MQTT packet handler to make it async, sync and runtime agnostic
-# then this is not needed
# quic = ["dep:quinn"]
-# native-tls = ["dep:async-native-tls"]
+
[dependencies]
# Packets
@@ -27,32 +27,29 @@ bitflags = "1.3.2"
# Errors
thiserror = "1.0.37"
tracing = "0.1.37"
-tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
async-channel = "1.8.0"
async-mutex = "1.4.0"
-futures-concurrency = "7.0.0"
futures = { version = "0.3.25", default-features = false, features = ["std", "async-await"] }
-
-# Needed for all TLS
-rustls = { version = "0.20.7", optional = true }
-rustls-pemfile = { version = "1.0.1", optional = true }
-webpki = { version = "0.22.0", optional = true }
+async-trait = "0.1.61"
# quic feature flag
# quinn = {version = "0.9.0", optional = true }
# tokio feature flag
tokio = { version = "1.21", features = ["macros", "io-util", "net", "time"], optional = true }
-tokio-rustls = { version = "0.23.4", optional = true }
# smol feature flag
smol = { version = "1.3.0", optional = true }
-async-rustls = { version = "0.3.0", optional = true }
-#async-native-tls = { version = "0.4.0", optional = true }
[dev-dependencies]
pretty_assertions = "1.3.0"
tokio = { version = "1.21", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] }
smol = { version = "1.3.0" }
-tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
\ No newline at end of file
+tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
+
+rustls = { version = "0.20.7" }
+rustls-pemfile = { version = "1.0.1" }
+webpki = { version = "0.22.0" }
+async-rustls = { version = "0.3.0" }
+tokio-rustls = "0.23.4"
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..ee6256c
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,373 @@
+Mozilla Public License Version 2.0
+==================================
+
+1. Definitions
+--------------
+
+1.1. "Contributor"
+ means each individual or legal entity that creates, contributes to
+ the creation of, or owns Covered Software.
+
+1.2. "Contributor Version"
+ means the combination of the Contributions of others (if any) used
+ by a Contributor and that particular Contributor's Contribution.
+
+1.3. "Contribution"
+ means Covered Software of a particular Contributor.
+
+1.4. "Covered Software"
+ means Source Code Form to which the initial Contributor has attached
+ the notice in Exhibit A, the Executable Form of such Source Code
+ Form, and Modifications of such Source Code Form, in each case
+ including portions thereof.
+
+1.5. "Incompatible With Secondary Licenses"
+ means
+
+ (a) that the initial Contributor has attached the notice described
+ in Exhibit B to the Covered Software; or
+
+ (b) that the Covered Software was made available under the terms of
+ version 1.1 or earlier of the License, but not also under the
+ terms of a Secondary License.
+
+1.6. "Executable Form"
+ means any form of the work other than Source Code Form.
+
+1.7. "Larger Work"
+ means a work that combines Covered Software with other material, in
+ a separate file or files, that is not Covered Software.
+
+1.8. "License"
+ means this document.
+
+1.9. "Licensable"
+ means having the right to grant, to the maximum extent possible,
+ whether at the time of the initial grant or subsequently, any and
+ all of the rights conveyed by this License.
+
+1.10. "Modifications"
+ means any of the following:
+
+ (a) any file in Source Code Form that results from an addition to,
+ deletion from, or modification of the contents of Covered
+ Software; or
+
+ (b) any new file in Source Code Form that contains any Covered
+ Software.
+
+1.11. "Patent Claims" of a Contributor
+ means any patent claim(s), including without limitation, method,
+ process, and apparatus claims, in any patent Licensable by such
+ Contributor that would be infringed, but for the grant of the
+ License, by the making, using, selling, offering for sale, having
+ made, import, or transfer of either its Contributions or its
+ Contributor Version.
+
+1.12. "Secondary License"
+ means either the GNU General Public License, Version 2.0, the GNU
+ Lesser General Public License, Version 2.1, the GNU Affero General
+ Public License, Version 3.0, or any later versions of those
+ licenses.
+
+1.13. "Source Code Form"
+ means the form of the work preferred for making modifications.
+
+1.14. "You" (or "Your")
+ means an individual or a legal entity exercising rights under this
+ License. For legal entities, "You" includes any entity that
+ controls, is controlled by, or is under common control with You. For
+ purposes of this definition, "control" means (a) the power, direct
+ or indirect, to cause the direction or management of such entity,
+ whether by contract or otherwise, or (b) ownership of more than
+ fifty percent (50%) of the outstanding shares or beneficial
+ ownership of such entity.
+
+2. License Grants and Conditions
+--------------------------------
+
+2.1. Grants
+
+Each Contributor hereby grants You a world-wide, royalty-free,
+non-exclusive license:
+
+(a) under intellectual property rights (other than patent or trademark)
+ Licensable by such Contributor to use, reproduce, make available,
+ modify, display, perform, distribute, and otherwise exploit its
+ Contributions, either on an unmodified basis, with Modifications, or
+ as part of a Larger Work; and
+
+(b) under Patent Claims of such Contributor to make, use, sell, offer
+ for sale, have made, import, and otherwise transfer either its
+ Contributions or its Contributor Version.
+
+2.2. Effective Date
+
+The licenses granted in Section 2.1 with respect to any Contribution
+become effective for each Contribution on the date the Contributor first
+distributes such Contribution.
+
+2.3. Limitations on Grant Scope
+
+The licenses granted in this Section 2 are the only rights granted under
+this License. No additional rights or licenses will be implied from the
+distribution or licensing of Covered Software under this License.
+Notwithstanding Section 2.1(b) above, no patent license is granted by a
+Contributor:
+
+(a) for any code that a Contributor has removed from Covered Software;
+ or
+
+(b) for infringements caused by: (i) Your and any other third party's
+ modifications of Covered Software, or (ii) the combination of its
+ Contributions with other software (except as part of its Contributor
+ Version); or
+
+(c) under Patent Claims infringed by Covered Software in the absence of
+ its Contributions.
+
+This License does not grant any rights in the trademarks, service marks,
+or logos of any Contributor (except as may be necessary to comply with
+the notice requirements in Section 3.4).
+
+2.4. Subsequent Licenses
+
+No Contributor makes additional grants as a result of Your choice to
+distribute the Covered Software under a subsequent version of this
+License (see Section 10.2) or under the terms of a Secondary License (if
+permitted under the terms of Section 3.3).
+
+2.5. Representation
+
+Each Contributor represents that the Contributor believes its
+Contributions are its original creation(s) or it has sufficient rights
+to grant the rights to its Contributions conveyed by this License.
+
+2.6. Fair Use
+
+This License is not intended to limit any rights You have under
+applicable copyright doctrines of fair use, fair dealing, or other
+equivalents.
+
+2.7. Conditions
+
+Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
+in Section 2.1.
+
+3. Responsibilities
+-------------------
+
+3.1. Distribution of Source Form
+
+All distribution of Covered Software in Source Code Form, including any
+Modifications that You create or to which You contribute, must be under
+the terms of this License. You must inform recipients that the Source
+Code Form of the Covered Software is governed by the terms of this
+License, and how they can obtain a copy of this License. You may not
+attempt to alter or restrict the recipients' rights in the Source Code
+Form.
+
+3.2. Distribution of Executable Form
+
+If You distribute Covered Software in Executable Form then:
+
+(a) such Covered Software must also be made available in Source Code
+ Form, as described in Section 3.1, and You must inform recipients of
+ the Executable Form how they can obtain a copy of such Source Code
+ Form by reasonable means in a timely manner, at a charge no more
+ than the cost of distribution to the recipient; and
+
+(b) You may distribute such Executable Form under the terms of this
+ License, or sublicense it under different terms, provided that the
+ license for the Executable Form does not attempt to limit or alter
+ the recipients' rights in the Source Code Form under this License.
+
+3.3. Distribution of a Larger Work
+
+You may create and distribute a Larger Work under terms of Your choice,
+provided that You also comply with the requirements of this License for
+the Covered Software. If the Larger Work is a combination of Covered
+Software with a work governed by one or more Secondary Licenses, and the
+Covered Software is not Incompatible With Secondary Licenses, this
+License permits You to additionally distribute such Covered Software
+under the terms of such Secondary License(s), so that the recipient of
+the Larger Work may, at their option, further distribute the Covered
+Software under the terms of either this License or such Secondary
+License(s).
+
+3.4. Notices
+
+You may not remove or alter the substance of any license notices
+(including copyright notices, patent notices, disclaimers of warranty,
+or limitations of liability) contained within the Source Code Form of
+the Covered Software, except that You may alter any license notices to
+the extent required to remedy known factual inaccuracies.
+
+3.5. Application of Additional Terms
+
+You may choose to offer, and to charge a fee for, warranty, support,
+indemnity or liability obligations to one or more recipients of Covered
+Software. However, You may do so only on Your own behalf, and not on
+behalf of any Contributor. You must make it absolutely clear that any
+such warranty, support, indemnity, or liability obligation is offered by
+You alone, and You hereby agree to indemnify every Contributor for any
+liability incurred by such Contributor as a result of warranty, support,
+indemnity or liability terms You offer. You may include additional
+disclaimers of warranty and limitations of liability specific to any
+jurisdiction.
+
+4. Inability to Comply Due to Statute or Regulation
+---------------------------------------------------
+
+If it is impossible for You to comply with any of the terms of this
+License with respect to some or all of the Covered Software due to
+statute, judicial order, or regulation then You must: (a) comply with
+the terms of this License to the maximum extent possible; and (b)
+describe the limitations and the code they affect. Such description must
+be placed in a text file included with all distributions of the Covered
+Software under this License. Except to the extent prohibited by statute
+or regulation, such description must be sufficiently detailed for a
+recipient of ordinary skill to be able to understand it.
+
+5. Termination
+--------------
+
+5.1. The rights granted under this License will terminate automatically
+if You fail to comply with any of its terms. However, if You become
+compliant, then the rights granted under this License from a particular
+Contributor are reinstated (a) provisionally, unless and until such
+Contributor explicitly and finally terminates Your grants, and (b) on an
+ongoing basis, if such Contributor fails to notify You of the
+non-compliance by some reasonable means prior to 60 days after You have
+come back into compliance. Moreover, Your grants from a particular
+Contributor are reinstated on an ongoing basis if such Contributor
+notifies You of the non-compliance by some reasonable means, this is the
+first time You have received notice of non-compliance with this License
+from such Contributor, and You become compliant prior to 30 days after
+Your receipt of the notice.
+
+5.2. If You initiate litigation against any entity by asserting a patent
+infringement claim (excluding declaratory judgment actions,
+counter-claims, and cross-claims) alleging that a Contributor Version
+directly or indirectly infringes any patent, then the rights granted to
+You by any and all Contributors for the Covered Software under Section
+2.1 of this License shall terminate.
+
+5.3. In the event of termination under Sections 5.1 or 5.2 above, all
+end user license agreements (excluding distributors and resellers) which
+have been validly granted by You or Your distributors under this License
+prior to termination shall survive termination.
+
+************************************************************************
+* *
+* 6. Disclaimer of Warranty *
+* ------------------------- *
+* *
+* Covered Software is provided under this License on an "as is" *
+* basis, without warranty of any kind, either expressed, implied, or *
+* statutory, including, without limitation, warranties that the *
+* Covered Software is free of defects, merchantable, fit for a *
+* particular purpose or non-infringing. The entire risk as to the *
+* quality and performance of the Covered Software is with You. *
+* Should any Covered Software prove defective in any respect, You *
+* (not any Contributor) assume the cost of any necessary servicing, *
+* repair, or correction. This disclaimer of warranty constitutes an *
+* essential part of this License. No use of any Covered Software is *
+* authorized under this License except under this disclaimer. *
+* *
+************************************************************************
+
+************************************************************************
+* *
+* 7. Limitation of Liability *
+* -------------------------- *
+* *
+* Under no circumstances and under no legal theory, whether tort *
+* (including negligence), contract, or otherwise, shall any *
+* Contributor, or anyone who distributes Covered Software as *
+* permitted above, be liable to You for any direct, indirect, *
+* special, incidental, or consequential damages of any character *
+* including, without limitation, damages for lost profits, loss of *
+* goodwill, work stoppage, computer failure or malfunction, or any *
+* and all other commercial damages or losses, even if such party *
+* shall have been informed of the possibility of such damages. This *
+* limitation of liability shall not apply to liability for death or *
+* personal injury resulting from such party's negligence to the *
+* extent applicable law prohibits such limitation. Some *
+* jurisdictions do not allow the exclusion or limitation of *
+* incidental or consequential damages, so this exclusion and *
+* limitation may not apply to You. *
+* *
+************************************************************************
+
+8. Litigation
+-------------
+
+Any litigation relating to this License may be brought only in the
+courts of a jurisdiction where the defendant maintains its principal
+place of business and such litigation shall be governed by laws of that
+jurisdiction, without reference to its conflict-of-law provisions.
+Nothing in this Section shall prevent a party's ability to bring
+cross-claims or counter-claims.
+
+9. Miscellaneous
+----------------
+
+This License represents the complete agreement concerning the subject
+matter hereof. If any provision of this License is held to be
+unenforceable, such provision shall be reformed only to the extent
+necessary to make it enforceable. Any law or regulation which provides
+that the language of a contract shall be construed against the drafter
+shall not be used to construe this License against a Contributor.
+
+10. Versions of the License
+---------------------------
+
+10.1. New Versions
+
+Mozilla Foundation is the license steward. Except as provided in Section
+10.3, no one other than the license steward has the right to modify or
+publish new versions of this License. Each version will be given a
+distinguishing version number.
+
+10.2. Effect of New Versions
+
+You may distribute the Covered Software under the terms of the version
+of the License under which You originally received the Covered Software,
+or under the terms of any subsequent version published by the license
+steward.
+
+10.3. Modified Versions
+
+If you create software not governed by this License, and you want to
+create a new license for such software, you may create and use a
+modified version of this License if you rename the license and remove
+any references to the name of the license steward (except to note that
+such modified license differs from this License).
+
+10.4. Distributing Source Code Form that is Incompatible With Secondary
+Licenses
+
+If You choose to distribute Source Code Form that is Incompatible With
+Secondary Licenses under the terms of this version of the License, the
+notice described in Exhibit B of this License must be attached.
+
+Exhibit A - Source Code Form License Notice
+-------------------------------------------
+
+ This Source Code Form is subject to the terms of the Mozilla Public
+ License, v. 2.0. If a copy of the MPL was not distributed with this
+ file, You can obtain one at https://mozilla.org/MPL/2.0/.
+
+If it is not possible or desirable to put the notice in a particular
+file, then You may include the notice in a location (such as a LICENSE
+file in a relevant directory) where a recipient would be likely to look
+for such a notice.
+
+You may add additional accurate notices of copyright ownership.
+
+Exhibit B - "Incompatible With Secondary Licenses" Notice
+---------------------------------------------------------
+
+ This Source Code Form is "Incompatible With Secondary Licenses", as
+ defined by the Mozilla Public License, v. 2.0.
diff --git a/README.md b/README.md
index 89f320c..2b3f334 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,25 @@
-# mqrstt
-Async MQTT client over TCP and QUIC
+
+
+# `📟 mqrstt`
+
+[](https://crates.io/crates/mqrstt)
+[](https://docs.rs/mqrstt)
+[](https://deps.rs/repo/github/GunnarMorrigan/mqrstt)
+[](https://codecov.io/github/GunnarMorrigan/mqrstt)
+
+`mqrstt` is an MQTTv5 client implementation that follows the [sans-io](https://sans-io.readthedocs.io/) approach.
+
+
+
+
+## License
+
+Licensed under
+
+* Mozilla Public License, Version 2.0, ([MPL-2.0](https://choosealicense.com/licenses/mpl-2.0/)
+
+## Contribution
+
+Unless you explicitly state otherwise, any contribution intentionally
+submitted for inclusion in the work by you, shall be licensed under MPL-2.0, without any additional terms or
+conditions.
\ No newline at end of file
diff --git a/deny.toml b/deny.toml
index 33fd274..b24eaba 100644
--- a/deny.toml
+++ b/deny.toml
@@ -8,8 +8,23 @@ unlicensed = "deny"
allow-osi-fsf-free = "neither"
copyleft = "deny"
confidence-threshold = 0.95
-allow = ["Apache-2.0", "MIT", "BSD-3-Clause", "ISC"]
+allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC"]
exceptions = [
{ allow = ["Unicode-DFS-2016"], name = "unicode-ident" },
+ { allow = ["OpenSSL"], name = "ring" }
+]
+
+[[licenses.clarify]]
+name = "ring"
+expression = "MIT AND ISC AND OpenSSL"
+license-files = [
+ { path = "LICENSE", hash = 0xbd0eed23 }
+]
+
+[[licenses.clarify]]
+name = "webpki"
+expression = "ISC"
+license-files = [
+ { path = "LICENSE", hash = 0x001c7e6c },
]
\ No newline at end of file
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
deleted file mode 100644
index 271800c..0000000
--- a/rust-toolchain.toml
+++ /dev/null
@@ -1,2 +0,0 @@
-[toolchain]
-channel = "nightly"
\ No newline at end of file
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000..e672bfc
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1,3 @@
+unstable_features = true
+brace_style = "PreferSameLine"
+control_brace_style = "ClosingNextLine"
\ No newline at end of file
diff --git a/src/available_packet_ids.rs b/src/available_packet_ids.rs
index 9502573..a6c2d80 100644
--- a/src/available_packet_ids.rs
+++ b/src/available_packet_ids.rs
@@ -1,5 +1,5 @@
use async_channel::{Receiver, Sender};
-use tracing::{debug, error};
+use tracing::error;
use crate::error::MqttError;
@@ -20,21 +20,21 @@ impl AvailablePacketIds {
(apkid, r)
}
- pub fn try_mark_available(&self, pkid: u16) -> Result<(), MqttError> {
- match self.sender.try_send(pkid) {
- Ok(_) => {
- Ok(())
- // debug!("Marked packet id as available: {}", pkid);
- }
- Err(err) => {
- error!(
- "Encountered an error while marking an packet id as available. Error: {}",
- err
- );
- Err(MqttError::PacketIdError(err.into_inner()))
- }
- }
- }
+ // pub fn try_mark_available(&self, pkid: u16) -> Result<(), MqttError> {
+ // match self.sender.try_send(pkid) {
+ // Ok(_) => {
+ // Ok(())
+ // // debug!("Marked packet id as available: {}", pkid);
+ // }
+ // Err(err) => {
+ // error!(
+ // "Encountered an error while marking an packet id as available. Error: {}",
+ // err
+ // );
+ // Err(MqttError::PacketIdError(err.into_inner()))
+ // }
+ // }
+ // }
pub async fn mark_available(&self, pkid: u16) -> Result<(), MqttError> {
match self.sender.send(pkid).await {
diff --git a/src/client.rs b/src/client.rs
index c0848ff..c4bd4c2 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -5,13 +5,10 @@ use tracing::info;
use crate::{
error::ClientError,
packets::{
- {Disconnect, DisconnectProperties},
- Packet,
- {Publish, PublishProperties},
reason_codes::DisconnectReasonCode,
+ Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties},
{Subscribe, SubscribeProperties, Subscription},
{Unsubscribe, UnsubscribeProperties, UnsubscribeTopics},
- QoS,
},
};
@@ -115,7 +112,8 @@ impl AsyncClient {
"Published message into network_packet_sender. len {}",
self.to_network_s.len()
);
- } else {
+ }
+ else {
self.client_to_handler_s
.send(Packet::Publish(publish))
.await
@@ -159,7 +157,8 @@ impl AsyncClient {
.send(Packet::Publish(publish))
.await
.map_err(|_| ClientError::NoHandler)?;
- } else {
+ }
+ else {
self.client_to_handler_s
.send(Packet::Publish(publish))
.await
diff --git a/src/connect_options.rs b/src/connect_options.rs
index 718a65d..85c560c 100644
--- a/src/connect_options.rs
+++ b/src/connect_options.rs
@@ -7,16 +7,6 @@ use crate::util::constants::RECEIVE_MAXIMUM_DEFAULT;
#[derive(Debug, Clone)]
pub struct ConnectOptions {
- /// broker address that you want to connect to
- pub(crate) address: String,
- /// broker port
- pub(crate) port: u16,
-
- #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
- pub(crate) tls_config: Option,
-
- // /// What transport protocol to use
- // transport: Transport,
/// keep alive time to send pingreq to broker when the connection is idle
pub keep_alive_interval_s: u64,
pub connection_timeout_s: u64,
@@ -28,38 +18,33 @@ pub struct ConnectOptions {
pub username: Option,
pub password: Option,
/// request (publish, subscribe) channel capacity
- channel_capacity: usize,
+ pub channel_capacity: usize,
/// Minimum delay time between consecutive outgoing packets
/// while retransmitting pending packets
// TODO! IMPLEMENT THIS!
- pending_throttle_s: u64,
+ pub pending_throttle_s: u64,
- send_reason_messages: bool,
+ pub send_reason_messages: bool,
// MQTT v5 Connect Properties:
- session_expiry_interval: Option,
- pub(crate) receive_maximum: Option,
- maximum_packet_size: Option,
- topic_alias_maximum: Option,
- request_response_information: Option,
- request_problem_information: Option,
- user_properties: Vec<(String, String)>,
- authentication_method: Option,
- authentication_data: Bytes,
+ pub session_expiry_interval: Option,
+ pub receive_maximum: Option,
+ pub maximum_packet_size: Option,
+ pub topic_alias_maximum: Option,
+ pub request_response_information: Option,
+ pub request_problem_information: Option,
+ pub user_properties: Vec<(String, String)>,
+ pub authentication_method: Option,
+ pub authentication_data: Bytes,
/// Last will that will be issued on unexpected disconnect
- last_will: Option,
+ pub last_will: Option,
}
impl ConnectOptions {
- pub fn new(address: String, port: u16, client_id: String) -> Self {
+ pub fn new(client_id: String) -> Self {
Self {
- address,
- port,
- #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
- tls_config: None,
-
keep_alive_interval_s: 60,
connection_timeout_s: 30,
clean_session: false,
@@ -83,46 +68,6 @@ impl ConnectOptions {
}
}
- #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
- pub(crate) fn new_with_tls_config(
- address: String,
- port: u16,
- client_id: String,
- tls_config: Option,
- ) -> Self {
- Self {
- address,
- port,
- tls_config,
-
- keep_alive_interval_s: 60,
- connection_timeout_s: 30,
- clean_session: false,
- client_id,
- credentials: None,
- channel_capacity: 100,
- pending_throttle_s: 30,
- send_reason_messages: false,
-
- session_expiry_interval: None,
- receive_maximum: None,
- maximum_packet_size: None,
- topic_alias_maximum: None,
- request_response_information: None,
- request_problem_information: None,
- user_properties: vec![],
- authentication_method: None,
- authentication_data: Bytes::new(),
- last_will: None,
- }
- }
-
- pub fn set_address(&mut self, address: String) {
- self.address = address
- }
- pub fn set_port(&mut self, port: u16) {
- self.port = port
- }
pub fn set_keep_alive_interval_s(&mut self, keep_alive_interval_s: u64) {
self.keep_alive_interval_s = keep_alive_interval_s
}
@@ -141,17 +86,12 @@ impl ConnectOptions {
pub fn set_pending_throttle_s(&mut self, pending_throttle_s: u64) {
self.pending_throttle_s = pending_throttle_s
}
- pub fn set_send_reason_messages(&mut self, send_reason_messages: bool) {
- self.send_reason_messages = send_reason_messages
- }
-
pub fn set_session_expiry_interval(&mut self, session_expiry_interval: u32) {
self.session_expiry_interval = Some(session_expiry_interval)
}
pub fn clear_session_expiry_interval(&mut self) {
self.session_expiry_interval = None
}
-
pub fn set_receive_maximum(&mut self, receive_maximum: u16) {
self.receive_maximum = Some(receive_maximum)
}
@@ -161,7 +101,6 @@ impl ConnectOptions {
pub fn receive_maximum(&self) -> u16 {
self.receive_maximum.unwrap_or(RECEIVE_MAXIMUM_DEFAULT)
}
-
pub fn set_maximum_packet_size(&mut self, maximum_packet_size: u32) {
self.maximum_packet_size = Some(maximum_packet_size)
}
@@ -175,13 +114,13 @@ impl ConnectOptions {
self.topic_alias_maximum = None
}
pub fn set_request_response_information(&mut self, request_response_information: bool) {
- self.request_response_information = Some(if request_response_information { 1 } else { 0 })
+ self.request_response_information = Some(u8::from(request_response_information))
}
pub fn clear_request_response_information(&mut self) {
self.request_response_information = None
}
pub fn set_request_problem_information(&mut self, request_problem_information: bool) {
- self.request_problem_information = Some(if request_problem_information { 1 } else { 0 })
+ self.request_problem_information = Some(u8::from(request_problem_information))
}
pub fn clear_request_problem_information(&mut self) {
self.request_problem_information = None
diff --git a/src/connections/async_native_tls.rs b/src/connections/async_native_tls.rs
deleted file mode 100644
index 8ab5e56..0000000
--- a/src/connections/async_native_tls.rs
+++ /dev/null
@@ -1,236 +0,0 @@
-use std::io::{self, Error, ErrorKind};
-
-use async_channel::Receiver;
-use async_native_tls::{TlsConnector, TlsStream};
-use bytes::{Buf, BytesMut};
-use smol::io::{ReadHalf, WriteHalf};
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt},
- net::TcpStream,
-};
-
-use tracing::trace;
-
-use crate::error::TlsError;
-use crate::{
- connect_options::ConnectOptions, connections::create_connect_from_options,
- error::ConnectionError, network::Incoming,
-};
-use crate::{
- connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite},
- packets::{
- error::ReadBytes,
- packets::{FixedHeader, Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
-};
-
-#[derive(Debug)]
-pub struct TlsReader {
- readhalf: ReadHalf>,
-
- /// Buffered reads
- buffer: BytesMut,
-}
-
-impl TlsReader {
- pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), ConnectionError> {
- if let Some(tls_config) = &options.tls_config {
- let addr = options.address.clone();
- let tcp = TcpStream::connect((addr.as_str(), options.port)).await?;
-
- let connector = TlsConnector::new().use_sni(true);
-
- let mut connection = async_native_tls::connect("google.com", tcp).await.unwrap();
-
- let (readhalf, writehalf) = split(connection);
-
- let reader = Self {
- readhalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- };
- let writer = TlsWriter::new(writehalf);
-
- Ok((reader, writer))
- } else {
- Err(ConnectionError::TLS(TlsError::NoTlsConfig))
- }
- }
-
- pub async fn read(&mut self) -> io::Result {
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
-
- return Packet::read(header, buf.into())
- .map_err(|err| Error::new(ErrorKind::InvalidData, err));
- }
- }
-
- /// Reads more than 'required' bytes to frame a packet into self.read buffer
- async fn read_bytes(&mut self, required: usize) -> io::Result {
- let mut total_read = 0;
-
- loop {
- #[cfg(feature = "tokio")]
- let read = self.readhalf.read_buf(&mut self.buffer).await?;
- #[cfg(feature = "smol")]
- let read = self.readhalf.read(&mut self.buffer).await?;
- if 0 == read {
- return if self.buffer.is_empty() {
- Err(io::Error::new(
- io::ErrorKind::ConnectionAborted,
- "Connection closed by peer",
- ))
- } else {
- Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "Connection reset by peer",
- ))
- };
- }
-
- total_read += read;
- if total_read >= required {
- return Ok(total_read);
- }
- }
- }
-}
-
-impl AsyncMqttNetworkRead for TlsReader {
- type W = TlsWriter;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl std::future::Future> + Send + '_
- {
- async {
- let (mut reader, mut writer) = TlsReader::new(options).await?;
-
- let mut buf_out = BytesMut::new();
-
- create_connect_from_options(options).write(&mut buf_out)?;
-
- writer.write_buffer(&mut buf_out).await?;
-
- let packet = reader.read().await?;
- if let Packet::ConnAck(con) = packet {
- if con.reason_code == ConnAckReasonCode::Success {
- Ok((reader, writer, Packet::ConnAck(con)))
- } else {
- Err(ConnectionError::ConnectionRefused(con.reason_code))
- }
- } else {
- Err(ConnectionError::NotConnAck(packet))
- }
- }
- }
-
- async fn read(&mut self) -> Result {
- Ok(self.read().await?)
- }
-
- async fn read_direct(
- &mut self,
- incoming_packet_sender: &async_channel::Sender,
- ) -> Result {
- let mut read_packets = 0;
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
- let read_packet = Packet::read(header, buf.into())?;
- tracing::trace!("Read packet from network {}", read_packet);
- let disconnect = read_packet.packet_type() == PacketType::Disconnect;
- incoming_packet_sender.send(read_packet).await?;
- if disconnect {
- return Ok(true);
- }
- read_packets += 1;
- if read_packets >= 10 {
- return Ok(false);
- }
- }
- }
-}
-pub struct TlsWriter {
- writehalf: WriteHalf>,
-
- buffer: BytesMut,
-}
-
-impl TlsWriter {
- pub fn new(writehalf: WriteHalf>) -> Self {
- Self {
- writehalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- }
- }
-}
-
-impl AsyncMqttNetworkWrite for TlsWriter {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
- if buffer.is_empty() {
- return Ok(());
- }
-
- self.writehalf.write_all(&buffer[..]).await?;
- buffer.clear();
- Ok(())
- }
-
- async fn write(&mut self, outgoing: &Receiver) -> Result {
- let mut disconnect = false;
-
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- }
-
- while !outgoing.is_empty() && !disconnect {
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- break;
- }
- trace!("Going to write packet to network: {:?}", packet);
- }
-
- self.writehalf.write_all(&self.buffer[..]).await?;
- self.writehalf.flush().await?;
- self.buffer.clear();
- Ok(disconnect)
- }
-}
diff --git a/src/connections/async_rustls.rs b/src/connections/async_rustls.rs
deleted file mode 100644
index 2e6f9f5..0000000
--- a/src/connections/async_rustls.rs
+++ /dev/null
@@ -1,297 +0,0 @@
-use std::io::{self, Error, ErrorKind};
-
-use async_channel::Receiver;
-
-use async_rustls::client::TlsStream;
-use async_rustls::TlsConnector;
-use bytes::{Buf, BytesMut};
-use rustls::ServerName;
-use smol::io::{ReadHalf, WriteHalf};
-use smol::{
- io::{split, AsyncReadExt, AsyncWriteExt},
- net::TcpStream,
-};
-
-use tracing::trace;
-
-use crate::error::TlsError;
-use crate::{
- connect_options::ConnectOptions, connections::create_connect_from_options,
- error::ConnectionError, network::Incoming,
-};
-use crate::{
- connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite},
- packets::{
- error::ReadBytes,
- {FixedHeader, Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
-};
-
-use super::transport::{RustlsConfig, TlsConfig};
-use super::util::simple_rust_tls;
-
-#[derive(Debug)]
-pub struct TlsReader {
- readhalf: ReadHalf>,
-
- /// Input buffer
- const_buffer: [u8; 1000],
- /// Buffered reads
- buffer: BytesMut,
-}
-
-impl TlsReader {
- pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), TlsError> {
- if let Some(tls_config) = &options.tls_config {
- let arc_tls_config = match tls_config {
- TlsConfig::Rustls(RustlsConfig::Simple {
- ca,
- alpn,
- client_auth,
- }) => simple_rust_tls(ca.clone(), alpn.clone(), client_auth.clone())?,
- TlsConfig::Rustls(RustlsConfig::Rustls(config)) => config.clone(),
- };
-
- let domain = ServerName::try_from(options.address.as_str())?;
- let connector = TlsConnector::from(arc_tls_config);
-
- let stream = TcpStream::connect((options.address.as_str(), options.port)).await?;
- let connection = connector.connect(domain, stream).await?;
-
- let (readhalf, writehalf) = split(connection);
-
- let reader = Self {
- readhalf,
- const_buffer: [0; 1000],
- buffer: BytesMut::with_capacity(20 * 1024),
- };
- let writer = TlsWriter::new(writehalf);
-
- Ok((reader, writer))
- } else {
- Err(TlsError::NoTlsConfig)
- }
- }
-
- pub async fn read(&mut self) -> io::Result {
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
-
- return Packet::read(header, buf.into())
- .map_err(|err| Error::new(ErrorKind::InvalidData, err));
- }
- }
-
- /// Reads more than 'required' bytes to frame a packet into self.read buffer
- async fn read_bytes(&mut self, required: usize) -> io::Result {
- let mut total_read = 0;
-
- loop {
- let read = self.readhalf.read(&mut self.const_buffer).await?;
-
- if read == 0 {
- return if self.buffer.is_empty() {
- Err(io::Error::new(
- io::ErrorKind::ConnectionAborted,
- "Connection closed by peer",
- ))
- } else {
- Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "Connection reset by peer",
- ))
- };
- }
- else {
- self.buffer.extend_from_slice(&self.const_buffer[0..read]);
- }
-
- total_read += read;
- if total_read >= required {
- return Ok(total_read);
- }
- }
- }
-}
-
-impl AsyncMqttNetworkRead for TlsReader {
- type W = TlsWriter;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl std::future::Future> + Send + '_
- {
- async {
- let (mut reader, mut writer) = TlsReader::new(options).await?;
-
- let mut buf_out = BytesMut::new();
-
- create_connect_from_options(options).write(&mut buf_out)?;
-
- writer.write_buffer(&mut buf_out).await?;
-
- let packet = reader.read().await?;
- if let Packet::ConnAck(con) = packet {
- if con.reason_code == ConnAckReasonCode::Success {
- Ok((reader, writer, Packet::ConnAck(con)))
- } else {
- Err(ConnectionError::ConnectionRefused(con.reason_code))
- }
- } else {
- Err(ConnectionError::NotConnAck(packet))
- }
- }
- }
-
- async fn read(&mut self) -> Result {
- Ok(self.read().await?)
- }
-
- async fn read_direct(
- &mut self,
- incoming_packet_sender: &async_channel::Sender,
- ) -> Result {
- let mut read_packets = 0;
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
- let read_packet = Packet::read(header, buf.into())?;
- tracing::trace!("Read packet from network {}", read_packet);
- let disconnect = read_packet.packet_type() == PacketType::Disconnect;
- incoming_packet_sender.send(read_packet).await?;
- if disconnect {
- return Ok(true);
- }
- read_packets += 1;
- if read_packets >= 10 {
- return Ok(false);
- }
- }
- }
-}
-
-#[derive(Debug)]
-pub struct TlsWriter {
- writehalf: WriteHalf>,
-
- buffer: BytesMut,
-}
-
-impl TlsWriter {
- pub fn new(writehalf: WriteHalf>) -> Self {
- Self {
- writehalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- }
- }
-}
-
-impl AsyncMqttNetworkWrite for TlsWriter {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
- if buffer.is_empty() {
- return Ok(());
- }
-
- self.writehalf.write_all(&buffer[..]).await?;
- self.writehalf.flush().await?;
- buffer.clear();
- Ok(())
- }
-
- async fn write(&mut self, outgoing: &Receiver) -> Result {
- let mut disconnect = false;
-
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- }
-
- while !outgoing.is_empty() && !disconnect {
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- break;
- }
- trace!("Going to write packet to network: {:?}", packet);
- }
-
- self.writehalf.write_all(&self.buffer[..]).await?;
- self.writehalf.flush().await?;
- self.buffer.clear();
- Ok(disconnect)
- }
-}
-
-#[cfg(test)]
-mod test {
- use crate::{
- connect_options::ConnectOptions,
- connections::{
- transport::{RustlsConfig, TlsConfig},
- AsyncMqttNetworkRead,
- },
- packets::{
- {Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
- tests::resources::EMQX_CERT,
- };
-
- use super::TlsReader;
-
- fn connect_emqx_test() {
- let config = TlsConfig::Rustls(RustlsConfig::Simple {
- ca: EMQX_CERT.to_vec(),
- alpn: None,
- client_auth: None,
- });
-
- let opt = ConnectOptions::new_with_tls_config(
- "broker.emqx.io".to_string(),
- 8883,
- "test123123".to_string(),
- Some(config),
- );
-
- let (_, _, packet) = smol::block_on(TlsReader::connect(&opt)).unwrap();
-
- assert_eq!(PacketType::ConnAck, packet.packet_type());
- if let Packet::ConnAck(conn) = packet {
- assert_eq!(ConnAckReasonCode::Success, conn.reason_code);
- }
- }
-}
diff --git a/src/connections/mod.rs b/src/connections/mod.rs
index 1a15d6a..740a9ac 100644
--- a/src/connections/mod.rs
+++ b/src/connections/mod.rs
@@ -1,77 +1,26 @@
-#[cfg(all(feature = "tokio", feature = "tcp"))]
-pub mod tokio_tcp;
-
-#[cfg(all(feature = "smol", feature = "tcp"))]
-pub mod smol_tcp;
-
-#[cfg(all(feature = "smol", feature = "native-tls"))]
-pub mod async_native_tls;
-
-#[cfg(all(feature = "smol", feature = "smol-rustls"))]
-pub mod async_rustls;
-
#[cfg(all(feature = "quic"))]
pub mod quic;
-
-#[cfg(all(feature = "tokio", feature = "tokio-rustls"))]
-pub mod tokio_rustls;
-
-#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
-pub mod transport;
-
-#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
-mod util;
-
-use std::future::Future;
-
-use async_channel::{Receiver, Sender};
-use bytes::BytesMut;
+#[cfg(feature = "smol")]
+pub mod smol_stream;
+#[cfg(feature = "tokio")]
+pub mod tokio_stream;
use crate::connect_options::ConnectOptions;
-use crate::error::ConnectionError;
use crate::packets::Connect;
use crate::packets::Packet;
pub fn create_connect_from_options(options: &ConnectOptions) -> Packet {
- let mut connect = Connect::default();
-
- connect.client_id = options.client_id.clone();
- connect.clean_session = options.clean_session;
- connect.keep_alive = options.keep_alive_interval_s as u16;
- connect.connect_properties.request_problem_information = Some(1u8);
- connect.connect_properties.request_response_information = Some(1u8);
- connect.username = options.username.clone();
- connect.password = options.password.clone();
+ let mut connect = Connect {
+ client_id: options.client_id.clone(),
+ clean_session: options.clean_session,
+ keep_alive: options.keep_alive_interval_s as u16,
+ username: options.username.clone(),
+ password: options.password.clone(),
+ ..Default::default()
+ };
+
+ connect.connect_properties.request_problem_information = options.request_problem_information;
+ connect.connect_properties.request_response_information = options.request_response_information;
Packet::Connect(connect)
}
-
-pub trait AsyncMqttNetwork: Sized + Sync + 'static {
- fn connect(
- options: &ConnectOptions,
- ) -> impl Future> + Send + '_;
-
- async fn read(&self) -> Result;
-
- async fn read_many(&self, receiver: &Sender) -> Result<(), ConnectionError>;
-
- async fn write(&self, write_buf: &mut BytesMut) -> Result<(), ConnectionError>;
-}
-
-pub trait AsyncMqttNetworkRead: Sized + Sync {
- type W;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl Future> + Send + '_;
-
- async fn read(&mut self) -> Result;
-
- async fn read_direct(&mut self, sender: &Sender) -> Result;
-}
-
-pub trait AsyncMqttNetworkWrite: Sized + Sync {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError>;
-
- async fn write(&mut self, outgoing: &Receiver) -> Result;
-}
diff --git a/src/connections/smol_stream.rs b/src/connections/smol_stream.rs
new file mode 100644
index 0000000..f3b4cc5
--- /dev/null
+++ b/src/connections/smol_stream.rs
@@ -0,0 +1,176 @@
+use std::io::{self, Error, ErrorKind};
+
+use bytes::{Buf, BytesMut};
+use smol::io::{AsyncReadExt, AsyncWriteExt};
+
+use futures::{AsyncRead, AsyncWrite};
+
+use tracing::trace;
+
+use crate::packets::{
+ error::ReadBytes,
+ reason_codes::ConnAckReasonCode,
+ {FixedHeader, Packet, PacketType},
+};
+use crate::{
+ connect_options::ConnectOptions, connections::create_connect_from_options,
+ error::ConnectionError,
+};
+
+#[derive(Debug)]
+pub struct SmolStream {
+ pub stream: S,
+
+ /// Input buffer
+ const_buffer: [u8; 1000],
+ /// Buffered reads
+ buffer: BytesMut,
+}
+
+impl SmolStream
+where
+ S: AsyncRead + AsyncWrite + Sized + Unpin,
+{
+ pub async fn connect(
+ options: &ConnectOptions,
+ stream: S,
+ ) -> Result<(Self, Packet), ConnectionError> {
+ let mut s = Self {
+ stream,
+ const_buffer: [0; 1000],
+ buffer: BytesMut::new(),
+ };
+
+ let mut buf_out = BytesMut::new();
+
+ create_connect_from_options(options).write(&mut buf_out)?;
+
+ s.write_buffer(&mut buf_out).await?;
+
+ let packet = s.read().await?;
+ if let Packet::ConnAck(con) = packet {
+ if con.reason_code == ConnAckReasonCode::Success {
+ Ok((s, Packet::ConnAck(con)))
+ }
+ else {
+ Err(ConnectionError::ConnectionRefused(con.reason_code))
+ }
+ }
+ else {
+ Err(ConnectionError::NotConnAck(packet))
+ }
+ }
+
+ pub async fn parse_messages(
+ &mut self,
+ incoming_packet_sender: &async_channel::Sender,
+ ) -> Result, ReadBytes> {
+ let mut ret_packet_type = None;
+ loop {
+ if self.buffer.is_empty() {
+ return Ok(ret_packet_type);
+ }
+ let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?;
+
+ if header.remaining_length > self.buffer.len() {
+ return Err(ReadBytes::InsufficientBytes(
+ header.remaining_length - self.buffer.len(),
+ ));
+ }
+
+ self.buffer.advance(header_length);
+
+ let buf = self.buffer.split_to(header.remaining_length);
+ let read_packet = Packet::read(header, buf.into())?;
+ tracing::trace!("Read packet from network {}", read_packet);
+ let packet_type = read_packet.packet_type();
+ incoming_packet_sender.send(read_packet).await?;
+
+ match packet_type {
+ PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)),
+ PacketType::PingResp => return Ok(Some(PacketType::PingResp)),
+ packet_type => ret_packet_type = Some(packet_type),
+ }
+ }
+ }
+
+ pub async fn read(&mut self) -> io::Result {
+ loop {
+ let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
+ Ok(header) => header,
+ Err(ReadBytes::InsufficientBytes(required_len)) => {
+ self.read_required_bytes(required_len).await?;
+ continue;
+ }
+ Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
+ };
+
+ self.buffer.advance(header_length);
+
+ if header.remaining_length > self.buffer.len() {
+ self.read_required_bytes(header.remaining_length - self.buffer.len())
+ .await?;
+ }
+
+ let buf = self.buffer.split_to(header.remaining_length);
+
+ return Packet::read(header, buf.into())
+ .map_err(|err| Error::new(ErrorKind::InvalidData, err));
+ }
+ }
+
+ pub async fn read_bytes(&mut self) -> io::Result {
+ let read = self.stream.read(&mut self.const_buffer).await?;
+ if 0 == read {
+ if self.buffer.is_empty() {
+ Err(io::Error::new(
+ io::ErrorKind::ConnectionAborted,
+ "Connection closed by peer",
+ ))
+ }
+ else {
+ Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ "Connection reset by peer",
+ ))
+ }
+ }
+ else {
+ self.buffer.extend_from_slice(&self.const_buffer[0..read]);
+ Ok(read)
+ }
+ }
+
+ /// Reads more than 'required' bytes to frame a packet into self.read buffer
+ pub async fn read_required_bytes(&mut self, required: usize) -> io::Result {
+ let mut total_read = 0;
+
+ loop {
+ let read = self.read_bytes().await?;
+ total_read += read;
+ if total_read >= required {
+ return Ok(total_read);
+ }
+ }
+ }
+
+ pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
+ if buffer.is_empty() {
+ return Ok(());
+ }
+
+ self.stream.write_all(&buffer[..]).await?;
+ buffer.clear();
+ Ok(())
+ }
+
+ pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> {
+ packet.write(&mut self.buffer)?;
+ trace!("Sending packet {}", packet);
+
+ self.stream.write_all(&self.buffer[..]).await?;
+ self.stream.flush().await?;
+ self.buffer.clear();
+ Ok(())
+ }
+}
diff --git a/src/connections/smol_tcp.rs b/src/connections/smol_tcp.rs
deleted file mode 100644
index 17ebaf6..0000000
--- a/src/connections/smol_tcp.rs
+++ /dev/null
@@ -1,220 +0,0 @@
-
-use std::io::{self, Error, ErrorKind};
-
-use async_channel::Receiver;
-use bytes::{Buf, BytesMut};
-// #[cfg(feature = "smol")]
-use smol::{
- io::{AsyncReadExt, AsyncWriteExt, split, ReadHalf, WriteHalf},
- net::TcpStream,
-};
-use tracing::trace;
-
-use crate::packets::{
- error::ReadBytes,
- {FixedHeader, Packet, PacketType},
- reason_codes::ConnAckReasonCode,
-};
-use crate::{
- connect_options::ConnectOptions,
- connections::{create_connect_from_options, AsyncMqttNetworkWrite},
- error::ConnectionError,
- network::Incoming,
-};
-
-use super::AsyncMqttNetworkRead;
-
-#[derive(Debug)]
-pub struct TcpReader {
- readhalf: ReadHalf,
-
- /// Input buffer
- const_buffer: [u8; 1000],
- /// Buffered reads
- buffer: BytesMut,
-}
-
-impl TcpReader {
- pub async fn new_tcp(
- options: &ConnectOptions,
- ) -> Result<(TcpReader, TcpWriter), ConnectionError> {
- let (readhalf, writehalf) = split(TcpStream::connect((options.address.clone(), options.port)).await?);
- let reader = TcpReader {
- readhalf,
- const_buffer: [0; 1000],
- buffer: BytesMut::with_capacity(20 * 1024),
- // max_incoming_size: u32::MAX as usize,
- };
- let writer = TcpWriter::new(writehalf);
- Ok((reader, writer))
- }
-
- pub async fn read(&mut self) -> io::Result {
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
-
- return Packet::read(header, buf.into())
- .map_err(|err| Error::new(ErrorKind::InvalidData, err));
- }
- }
-
- /// Reads more than 'required' bytes to frame a packet into self.read buffer
- async fn read_bytes(&mut self, required: usize) -> io::Result {
- let mut total_read = 0;
-
- loop {
- let read = self.readhalf.read(&mut self.const_buffer).await?;
- if 0 == read {
- return if self.buffer.is_empty() {
- Err(io::Error::new(
- io::ErrorKind::ConnectionAborted,
- "Connection closed by peer",
- ))
- } else {
- Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "Connection reset by peer",
- ))
- };
- }
- else {
- self.buffer.extend_from_slice(&self.const_buffer[0..read]);
- }
-
- total_read += read;
- if total_read >= required {
- return Ok(total_read);
- }
- }
- }
-}
-
-impl AsyncMqttNetworkRead for TcpReader {
- type W = TcpWriter;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl std::future::Future> + Send + '_
- {
- async {
- let (mut reader, mut writer) = TcpReader::new_tcp(options).await?;
-
- let mut buf_out = BytesMut::new();
-
- create_connect_from_options(options).write(&mut buf_out)?;
-
- writer.write_buffer(&mut buf_out).await?;
-
- let packet = reader.read().await?;
- if let Packet::ConnAck(con) = packet {
- if con.reason_code == ConnAckReasonCode::Success {
- Ok((reader, writer, Packet::ConnAck(con)))
- } else {
- Err(ConnectionError::ConnectionRefused(con.reason_code))
- }
- } else {
- Err(ConnectionError::NotConnAck(packet))
- }
- }
- }
-
- async fn read(&mut self) -> Result {
- Ok(self.read().await?)
- }
-
- async fn read_direct(
- &mut self,
- incoming_packet_sender: &async_channel::Sender,
- ) -> Result {
- let mut read_packets = 0;
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
- let read_packet = Packet::read(header, buf.into())?;
- tracing::trace!("Read packet from network {}", read_packet);
- let disconnect = read_packet.packet_type() == PacketType::Disconnect;
- incoming_packet_sender.send(read_packet).await?;
- if disconnect {
- return Ok(true);
- }
- read_packets += 1;
- if read_packets >= 10 {
- return Ok(false);
- }
- }
- }
-}
-
-pub struct TcpWriter {
- writehalf: WriteHalf,
-
- buffer: BytesMut,
-}
-
-impl TcpWriter {
- pub fn new(writehalf: WriteHalf) -> Self {
- Self {
- writehalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- }
- }
-}
-
-impl AsyncMqttNetworkWrite for TcpWriter {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
- if buffer.is_empty() {
- return Ok(());
- }
-
- self.writehalf.write_all(&buffer[..]).await?;
- buffer.clear();
- Ok(())
- }
-
- async fn write(&mut self, outgoing: &Receiver) -> Result {
- let mut disconnect = false;
-
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- }
- trace!("Sending packet {}", packet);
-
- self.writehalf.write_all(&self.buffer[..]).await?;
- self.writehalf.flush().await?;
- self.buffer.clear();
- Ok(disconnect)
- }
-}
diff --git a/src/connections/tokio_rustls.rs b/src/connections/tokio_rustls.rs
deleted file mode 100644
index a09bc58..0000000
--- a/src/connections/tokio_rustls.rs
+++ /dev/null
@@ -1,290 +0,0 @@
-use std::io::{self, Error, ErrorKind};
-
-use async_channel::Receiver;
-
-use bytes::{Buf, BytesMut};
-use rustls::ServerName;
-
-use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
-use tokio::net::TcpStream;
-use tokio_rustls::client::TlsStream;
-use tokio_rustls::TlsConnector;
-use tracing::trace;
-
-use crate::error::TlsError;
-use crate::{
- connect_options::ConnectOptions, connections::create_connect_from_options,
- error::ConnectionError, network::Incoming,
-};
-use crate::{
- connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite},
- packets::{
- error::ReadBytes,
- {FixedHeader, Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
-};
-
-use super::transport::{RustlsConfig, TlsConfig};
-use super::util::simple_rust_tls;
-
-#[derive(Debug)]
-pub struct TlsReader {
- readhalf: ReadHalf>,
-
- /// Buffered reads
- buffer: BytesMut,
-}
-
-impl TlsReader {
- pub async fn new(options: &ConnectOptions) -> Result<(TlsReader, TlsWriter), TlsError> {
- if let Some(tls_config) = &options.tls_config {
- let arc_tls_config = match tls_config {
- TlsConfig::Rustls(RustlsConfig::Simple {
- ca,
- alpn,
- client_auth,
- }) => simple_rust_tls(ca.clone(), alpn.clone(), client_auth.clone())?,
- TlsConfig::Rustls(RustlsConfig::Rustls(config)) => config.clone(),
- };
-
- let domain = ServerName::try_from(options.address.as_str())?;
- let connector = TlsConnector::from(arc_tls_config);
-
- let stream = TcpStream::connect((options.address.as_str(), options.port)).await?;
- let connection = connector.connect(domain, stream).await?;
-
- let (readhalf, writehalf) = tokio::io::split(connection);
-
- let reader = Self {
- readhalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- };
- let writer = TlsWriter::new(writehalf);
-
- Ok((reader, writer))
- } else {
- Err(TlsError::NoTlsConfig)
- }
- }
-
- pub async fn read(&mut self) -> io::Result {
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
-
- return Packet::read(header, buf.into())
- .map_err(|err| Error::new(ErrorKind::InvalidData, err));
- }
- }
-
- /// Reads more than 'required' bytes to frame a packet into self.read buffer
- async fn read_bytes(&mut self, required: usize) -> io::Result {
- let mut total_read = 0;
- loop {
- let read = self.readhalf.read_buf(&mut self.buffer).await?;
-
- if read == 0 {
- return if self.buffer.is_empty() {
- Err(io::Error::new(
- io::ErrorKind::ConnectionAborted,
- "Connection closed by peer",
- ))
- } else {
- Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "Connection reset by peer",
- ))
- };
- }
-
- total_read += read;
- if total_read >= required {
- return Ok(total_read);
- }
- }
- }
-}
-
-impl AsyncMqttNetworkRead for TlsReader {
- type W = TlsWriter;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl std::future::Future> + Send + '_
- {
- async {
- let (mut reader, mut writer) = TlsReader::new(options).await?;
-
- let mut buf_out = BytesMut::new();
-
- create_connect_from_options(options).write(&mut buf_out)?;
-
- writer.write_buffer(&mut buf_out).await?;
-
- if !buf_out.is_empty() {
- panic!("Should be empty");
- }
-
- let packet = reader.read().await?;
- if let Packet::ConnAck(con) = packet {
- if con.reason_code == ConnAckReasonCode::Success {
- Ok((reader, writer, Packet::ConnAck(con)))
- } else {
- Err(ConnectionError::ConnectionRefused(con.reason_code))
- }
- } else {
- Err(ConnectionError::NotConnAck(packet))
- }
- }
- }
-
- async fn read(&mut self) -> Result {
- Ok(self.read().await?)
- }
-
- async fn read_direct(
- &mut self,
- incoming_packet_sender: &async_channel::Sender,
- ) -> Result {
- let mut read_packets = 0;
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
- let read_packet = Packet::read(header, buf.into())?;
- tracing::trace!("Read packet from network {}", read_packet);
- let disconnect = read_packet.packet_type() == PacketType::Disconnect;
- incoming_packet_sender.send(read_packet).await?;
- if disconnect {
- return Ok(true);
- }
- read_packets += 1;
- if read_packets >= 10 {
- return Ok(false);
- }
- }
- }
-}
-
-#[derive(Debug)]
-pub struct TlsWriter {
- writehalf: WriteHalf>,
- buffer: BytesMut,
-}
-
-impl TlsWriter {
- pub fn new(writehalf: WriteHalf>) -> Self {
- Self {
- writehalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- }
- }
-}
-
-impl AsyncMqttNetworkWrite for TlsWriter {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
- if buffer.is_empty() {
- return Ok(());
- }
-
- self.writehalf.write_all(&buffer[..]).await?;
- self.writehalf.flush().await?;
- buffer.clear();
- Ok(())
- }
-
- async fn write(&mut self, outgoing: &Receiver) -> Result {
- let mut disconnect = false;
-
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- }
-
- while !outgoing.is_empty() && !disconnect {
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- break;
- }
- trace!("Going to write packet to network: {:?}", packet);
- }
-
- self.writehalf.write_all(&self.buffer[..]).await?;
- self.writehalf.flush().await?;
- self.buffer.clear();
- Ok(disconnect)
- }
-}
-
-#[cfg(test)]
-mod test {
- use crate::{
- connect_options::ConnectOptions,
- connections::{
- transport::{RustlsConfig, TlsConfig},
- AsyncMqttNetworkRead,
- },
- packets::{
- {Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
- tests::resources::EMQX_CERT,
- };
-
- use super::TlsReader;
-
- async fn connect_emqx_test() {
- let config = TlsConfig::Rustls(RustlsConfig::Simple {
- ca: EMQX_CERT.to_vec(),
- alpn: None,
- client_auth: None,
- });
-
- let opt = ConnectOptions::new_with_tls_config(
- "broker.emqx.io".to_string(),
- 8883,
- "test123123".to_string(),
- Some(config),
- );
-
- let (_, _, packet) = TlsReader::connect(&opt).await.unwrap();
-
- assert_eq!(PacketType::ConnAck, packet.packet_type());
- if let Packet::ConnAck(conn) = packet {
- assert_eq!(ConnAckReasonCode::Success, conn.reason_code);
- }
- }
-}
diff --git a/src/connections/tokio_stream.rs b/src/connections/tokio_stream.rs
new file mode 100644
index 0000000..954e87c
--- /dev/null
+++ b/src/connections/tokio_stream.rs
@@ -0,0 +1,174 @@
+use std::io::{self, Error, ErrorKind};
+
+use bytes::{Buf, BytesMut};
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+
+use tracing::trace;
+
+use crate::packets::{
+ error::ReadBytes,
+ reason_codes::ConnAckReasonCode,
+ {FixedHeader, Packet, PacketType},
+};
+use crate::{
+ connect_options::ConnectOptions, connections::create_connect_from_options,
+ error::ConnectionError,
+};
+
+#[derive(Debug)]
+pub struct TokioStream {
+ pub stream: S,
+
+ /// Input buffer
+ const_buffer: [u8; 1000],
+ /// Buffered reads
+ buffer: BytesMut,
+}
+
+impl TokioStream
+where
+ S: AsyncRead + AsyncWrite + Sized + Unpin,
+{
+ pub async fn connect(
+ options: &ConnectOptions,
+ stream: S,
+ ) -> Result<(Self, Packet), ConnectionError> {
+ let mut s = Self {
+ stream,
+ const_buffer: [0; 1000],
+ buffer: BytesMut::new(),
+ };
+
+ let mut buf_out = BytesMut::new();
+
+ create_connect_from_options(options).write(&mut buf_out)?;
+
+ s.write_buffer(&mut buf_out).await?;
+
+ let packet = s.read().await?;
+ if let Packet::ConnAck(con) = packet {
+ if con.reason_code == ConnAckReasonCode::Success {
+ Ok((s, Packet::ConnAck(con)))
+ }
+ else {
+ Err(ConnectionError::ConnectionRefused(con.reason_code))
+ }
+ }
+ else {
+ Err(ConnectionError::NotConnAck(packet))
+ }
+ }
+
+ pub async fn parse_messages(
+ &mut self,
+ incoming_packet_sender: &async_channel::Sender,
+ ) -> Result, ReadBytes> {
+ let mut ret_packet_type = None;
+ loop {
+ if self.buffer.is_empty() {
+ return Ok(ret_packet_type);
+ }
+ let (header, header_length) = FixedHeader::read_fixed_header(self.buffer.iter())?;
+
+ if header.remaining_length > self.buffer.len() {
+ return Err(ReadBytes::InsufficientBytes(
+ header.remaining_length - self.buffer.len(),
+ ));
+ }
+
+ self.buffer.advance(header_length);
+
+ let buf = self.buffer.split_to(header.remaining_length);
+ let read_packet = Packet::read(header, buf.into())?;
+ tracing::trace!("Read packet from network {}", read_packet);
+ let packet_type = read_packet.packet_type();
+ incoming_packet_sender.send(read_packet).await?;
+
+ match packet_type {
+ PacketType::Disconnect => return Ok(Some(PacketType::Disconnect)),
+ PacketType::PingResp => return Ok(Some(PacketType::PingResp)),
+ packet_type => ret_packet_type = Some(packet_type),
+ }
+ }
+ }
+
+ pub async fn read(&mut self) -> io::Result {
+ loop {
+ let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
+ Ok(header) => header,
+ Err(ReadBytes::InsufficientBytes(required_len)) => {
+ self.read_required_bytes(required_len).await?;
+ continue;
+ }
+ Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
+ };
+
+ self.buffer.advance(header_length);
+
+ if header.remaining_length > self.buffer.len() {
+ self.read_required_bytes(header.remaining_length - self.buffer.len())
+ .await?;
+ }
+
+ let buf = self.buffer.split_to(header.remaining_length);
+
+ return Packet::read(header, buf.into())
+ .map_err(|err| Error::new(ErrorKind::InvalidData, err));
+ }
+ }
+
+ pub async fn read_bytes(&mut self) -> io::Result {
+ let read = self.stream.read(&mut self.const_buffer).await?;
+ if 0 == read {
+ if self.buffer.is_empty() {
+ Err(io::Error::new(
+ io::ErrorKind::ConnectionAborted,
+ "Connection closed by peer",
+ ))
+ }
+ else {
+ Err(io::Error::new(
+ io::ErrorKind::ConnectionReset,
+ "Connection reset by peer",
+ ))
+ }
+ }
+ else {
+ self.buffer.extend_from_slice(&self.const_buffer[0..read]);
+ Ok(read)
+ }
+ }
+
+ /// Reads more than 'required' bytes to frame a packet into self.read buffer
+ pub async fn read_required_bytes(&mut self, required: usize) -> io::Result {
+ let mut total_read = 0;
+
+ loop {
+ let read = self.read_bytes().await?;
+ total_read += read;
+ if total_read >= required {
+ return Ok(total_read);
+ }
+ }
+ }
+
+ pub async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
+ if buffer.is_empty() {
+ return Ok(());
+ }
+
+ self.stream.write_all(&buffer[..]).await?;
+ buffer.clear();
+ Ok(())
+ }
+
+ pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> {
+ packet.write(&mut self.buffer)?;
+ trace!("Sending packet {}", packet);
+
+ self.stream.write_all(&self.buffer[..]).await?;
+ self.stream.flush().await?;
+ self.buffer.clear();
+ Ok(())
+ }
+}
diff --git a/src/connections/tokio_tcp.rs b/src/connections/tokio_tcp.rs
deleted file mode 100644
index 5a90ba5..0000000
--- a/src/connections/tokio_tcp.rs
+++ /dev/null
@@ -1,223 +0,0 @@
-use std::io::{self, Error, ErrorKind};
-
-use async_channel::Receiver;
-use bytes::{Buf, BytesMut};
-// #[cfg(feature = "smol")]
-// use smol::{
-// io::{AsyncReadExt, AsyncWriteExt},
-// net::TcpStream,
-// };
-use tokio::{io::AsyncReadExt, net::TcpStream};
-use tokio::{
- io::AsyncWriteExt,
- net::tcp::{OwnedReadHalf, OwnedWriteHalf},
-};
-use tracing::trace;
-
-use crate::packets::{
- error::ReadBytes,
- {FixedHeader, Packet, PacketType},
- reason_codes::ConnAckReasonCode,
-};
-use crate::{
- connect_options::ConnectOptions,
- connections::{create_connect_from_options, AsyncMqttNetworkWrite},
- error::ConnectionError,
- network::Incoming,
-};
-
-use super::AsyncMqttNetworkRead;
-
-#[derive(Debug)]
-pub struct TcpReader {
- readhalf: OwnedReadHalf,
-
- /// Buffered reads
- buffer: BytesMut,
-}
-
-impl TcpReader {
- pub async fn new_tcp(
- options: &ConnectOptions,
- ) -> Result<(TcpReader, TcpWriter), ConnectionError> {
- let (readhalf, writehalf) = TcpStream::connect((options.address.clone(), options.port))
- .await?
- .into_split();
- let reader = TcpReader {
- readhalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- // max_incoming_size: u32::MAX as usize,
- };
- let writer = TcpWriter::new(writehalf);
- Ok((reader, writer))
- }
-
- pub async fn read(&mut self) -> io::Result {
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
-
- return Packet::read(header, buf.into())
- .map_err(|err| Error::new(ErrorKind::InvalidData, err));
- }
- }
-
- /// Reads more than 'required' bytes to frame a packet into self.read buffer
- async fn read_bytes(&mut self, required: usize) -> io::Result {
- let mut total_read = 0;
-
- loop {
- #[cfg(feature = "tokio")]
- let read = self.readhalf.read_buf(&mut self.buffer).await?;
- // #[cfg(feature = "smol")]
- // let read = self.connection.read(&mut self.buffer).await?;
- if 0 == read {
- return if self.buffer.is_empty() {
- Err(io::Error::new(
- io::ErrorKind::ConnectionAborted,
- "Connection closed by peer",
- ))
- } else {
- Err(io::Error::new(
- io::ErrorKind::ConnectionReset,
- "Connection reset by peer",
- ))
- };
- }
-
- total_read += read;
- if total_read >= required {
- return Ok(total_read);
- }
- }
- }
-}
-
-impl AsyncMqttNetworkRead for TcpReader {
- type W = TcpWriter;
-
- fn connect(
- options: &ConnectOptions,
- ) -> impl std::future::Future> + Send + '_
- {
- async {
- let (mut reader, mut writer) = TcpReader::new_tcp(options).await?;
-
- let mut buf_out = BytesMut::new();
-
- create_connect_from_options(options).write(&mut buf_out)?;
-
- writer.write_buffer(&mut buf_out).await?;
-
- let packet = reader.read().await?;
- if let Packet::ConnAck(con) = packet {
- if con.reason_code == ConnAckReasonCode::Success {
- Ok((reader, writer, Packet::ConnAck(con)))
- } else {
- Err(ConnectionError::ConnectionRefused(con.reason_code))
- }
- } else {
- Err(ConnectionError::NotConnAck(packet))
- }
- }
- }
-
- async fn read(&mut self) -> Result {
- Ok(self.read().await?)
- }
-
- async fn read_direct(
- &mut self,
- incoming_packet_sender: &async_channel::Sender,
- ) -> Result {
- let mut read_packets = 0;
- loop {
- let (header, header_length) = match FixedHeader::read_fixed_header(self.buffer.iter()) {
- Ok(header) => header,
- Err(ReadBytes::InsufficientBytes(required_len)) => {
- self.read_bytes(required_len).await?;
- continue;
- }
- Err(ReadBytes::Err(err)) => return Err(ConnectionError::DeserializationError(err)),
- };
-
- self.buffer.advance(header_length);
-
- if header.remaining_length > self.buffer.len() {
- self.read_bytes(header.remaining_length - self.buffer.len())
- .await?;
- }
-
- let buf = self.buffer.split_to(header.remaining_length);
- let read_packet = Packet::read(header, buf.into())?;
- tracing::trace!("Read packet from network {}", read_packet);
- let disconnect = read_packet.packet_type() == PacketType::Disconnect;
- incoming_packet_sender.send(read_packet).await?;
- if disconnect {
- return Ok(true);
- }
- read_packets += 1;
- if read_packets >= 10 {
- return Ok(false);
- }
- }
- }
-}
-
-pub struct TcpWriter {
- writehalf: OwnedWriteHalf,
-
- buffer: BytesMut,
-}
-
-impl TcpWriter {
- pub fn new(writehalf: OwnedWriteHalf) -> Self {
- Self {
- writehalf,
- buffer: BytesMut::with_capacity(20 * 1024),
- }
- }
-}
-
-impl AsyncMqttNetworkWrite for TcpWriter {
- async fn write_buffer(&mut self, buffer: &mut BytesMut) -> Result<(), ConnectionError> {
- if buffer.is_empty() {
- return Ok(());
- }
-
- self.writehalf.write_all(&buffer[..]).await?;
- buffer.clear();
- Ok(())
- }
-
- async fn write(&mut self, outgoing: &Receiver) -> Result {
- let mut disconnect = false;
-
- let packet = outgoing.recv().await?;
- packet.write(&mut self.buffer)?;
- if packet.packet_type() == PacketType::Disconnect {
- disconnect = true;
- }
- trace!("Sending packet {}", packet);
-
- self.writehalf.write_all(&self.buffer[..]).await?;
- self.writehalf.flush().await?;
- self.buffer.clear();
- Ok(disconnect)
- }
-}
diff --git a/src/connections/transport.rs b/src/connections/transport.rs
deleted file mode 100644
index 9101622..0000000
--- a/src/connections/transport.rs
+++ /dev/null
@@ -1,31 +0,0 @@
-use rustls::ClientConfig;
-use std::sync::Arc;
-
-#[derive(Debug, Clone)]
-pub enum TlsConfig {
- #[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
- Rustls(RustlsConfig),
- #[cfg(feature = "native-tls")]
- Native {
- ca: Vec,
- der: Vec,
- password: String,
- },
-}
-
-#[cfg(any(feature = "smol-rustls", feature = "tokio-rustls"))]
-#[derive(Debug, Clone)]
-pub enum RustlsConfig {
- Simple {
- ca: Vec,
- alpn: Option>>,
- client_auth: Option<(Vec, PrivateKey)>,
- },
- Rustls(Arc),
-}
-
-#[derive(Debug, Clone)]
-pub enum PrivateKey {
- RSA(Vec),
- ECC(Vec),
-}
diff --git a/src/connections/util.rs b/src/connections/util.rs
deleted file mode 100644
index 26adf12..0000000
--- a/src/connections/util.rs
+++ /dev/null
@@ -1,74 +0,0 @@
-use std::{
- io::{BufReader, Cursor},
- sync::Arc,
-};
-
-use rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore};
-
-use super::transport::PrivateKey;
-use crate::error::TlsError;
-
-// pub fn native() {}
-
-pub fn simple_rust_tls(
- ca: Vec,
- alpn: Option>>,
- client_auth: Option<(Vec, PrivateKey)>,
-) -> Result, TlsError> {
- let mut root_cert_store = RootCertStore::empty();
-
- let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca)))?;
-
- let trust_anchors = ca_certs.iter().map_while(|cert| {
- if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) {
- Some(OwnedTrustAnchor::from_subject_spki_name_constraints(
- ta.subject,
- ta.spki,
- ta.name_constraints,
- ))
- } else {
- None
- }
- });
- root_cert_store.add_server_trust_anchors(trust_anchors);
-
- if root_cert_store.is_empty() {
- return Err(TlsError::NoValidRootCertInChain);
- }
-
- let config = ClientConfig::builder()
- .with_safe_defaults()
- .with_root_certificates(root_cert_store);
-
- let mut config = match client_auth {
- Some((client_cert_info, client_private_info)) => {
- let read_private_keys = match client_private_info {
- PrivateKey::RSA(rsa) => {
- rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa)))
- }
- PrivateKey::ECC(ecc) => {
- rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc)))
- }
- }
- .map_err(|_| TlsError::NoValidPrivateKey)?;
-
- let key = read_private_keys
- .into_iter()
- .next()
- .ok_or(TlsError::NoValidPrivateKey)?;
-
- let client_certs =
- rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info)))?;
- let client_cert_chain = client_certs.into_iter().map(Certificate).collect();
-
- config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))?
- }
- None => config.with_no_client_auth(),
- };
-
- if let Some(alpn) = alpn {
- config.alpn_protocols.extend(alpn)
- }
-
- Ok(Arc::new(config))
-}
diff --git a/src/error.rs b/src/error.rs
index dd93d94..296dce3 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -2,13 +2,10 @@ use std::io;
use async_channel::{RecvError, SendError};
-use crate::{
- packets::{
- error::{DeserializeError, SerializeError},
- {Packet, PacketType},
- reason_codes::ConnAckReasonCode,
- },
- util::timeout::Timeout,
+use crate::packets::{
+ error::{DeserializeError, ReadBytes, SerializeError},
+ reason_codes::ConnAckReasonCode,
+ {Packet, PacketType},
};
#[derive(Debug, Clone, thiserror::Error)]
@@ -44,14 +41,8 @@ pub enum ClientError {
/// Critical errors during eventloop polling
#[derive(Debug, thiserror::Error)]
pub enum ConnectionError {
- // #[error("Mqtt state: {0}")]
- // MqttState(#[from] StateError),
- #[error("Connect timeout")]
- Timeout(#[from] Timeout),
-
- #[cfg(feature = "use-rustls")]
- #[error("TLS: {0}")]
- Tls(#[from] tls::Error),
+ #[error("No network connection")]
+ NoNetwork,
#[error("No incoming packet handler available: {0}")]
NoIncomingPacketHandler(#[from] SendError),
@@ -76,9 +67,27 @@ pub enum ConnectionError {
#[error("Requests done")]
RequestsDone,
+}
+
+impl From> for ReadBytes {
+ fn from(value: ReadBytes) -> Self {
+ match value {
+ ReadBytes::Err(err) => ReadBytes::Err(err.into()),
+ ReadBytes::InsufficientBytes(id) => ReadBytes::InsufficientBytes(id),
+ }
+ }
+}
+
+impl From for ReadBytes {
+ fn from(value: DeserializeError) -> Self {
+ ReadBytes::Err(value.into())
+ }
+}
- #[error("TLS Error")]
- TLS(#[from] TlsError),
+impl From> for ReadBytes {
+ fn from(value: SendError) -> Self {
+ ReadBytes::Err(value.into())
+ }
}
#[derive(Debug, thiserror::Error)]
diff --git a/src/event_handler.rs b/src/event_handler.rs
index 5f81ad5..70035d5 100644
--- a/src/event_handler.rs
+++ b/src/event_handler.rs
@@ -1,36 +1,31 @@
use crate::connect_options::ConnectOptions;
use crate::error::MqttError;
+use crate::packets::reason_codes::{PubAckReasonCode, PubRecReasonCode};
use crate::packets::Disconnect;
-use crate::packets::{Packet, PacketType};
-use crate::packets::{PubAck, PubAckProperties};
use crate::packets::PubComp;
-use crate::packets::Publish;
use crate::packets::PubRec;
use crate::packets::PubRel;
-use crate::packets::reason_codes::{PubAckReasonCode, PubRecReasonCode};
+use crate::packets::Publish;
+use crate::packets::QoS;
use crate::packets::SubAck;
use crate::packets::Subscribe;
use crate::packets::UnsubAck;
use crate::packets::Unsubscribe;
-use crate::packets::QoS;
+use crate::packets::{Packet, PacketType};
+use crate::packets::{PubAck, PubAckProperties};
use crate::state::State;
+use crate::{AsyncEventHandler, AsyncEventHandlerMut, HandlerStatus};
use futures::FutureExt;
use async_channel::{Receiver, Sender};
-use async_mutex::Mutex;
-use tracing::{error, debug};
+use tracing::error;
-use std::future::Future;
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-use std::time::Instant;
+#[cfg(test)]
+use tracing::debug;
/// Eventloop with all the state of a connection
pub struct EventHandlerTask {
- /// Options of the current mqtt connection
- // options: ConnectOptions,
- /// Current state of the connection
state: State,
network_receiver: Receiver,
@@ -39,11 +34,6 @@ pub struct EventHandlerTask {
client_to_handler_r: Receiver,
- last_network_action: Arc>,
- atomic_waiting_for_pingresp: AtomicBool,
- waiting_for_pingresp: bool,
- keep_alive_s: u64,
- atomic_disconnect: AtomicBool,
disconnect: bool,
}
@@ -57,7 +47,6 @@ impl EventHandlerTask {
network_receiver: Receiver,
network_sender: Sender,
client_to_handler_r: Receiver,
- last_network_action: Arc>,
) -> (Self, Receiver) {
let (state, packet_id_channel) = State::new(options.receive_maximum());
@@ -69,90 +58,84 @@ impl EventHandlerTask {
client_to_handler_r,
- last_network_action,
-
- atomic_waiting_for_pingresp: AtomicBool::new(false),
- waiting_for_pingresp: false,
-
- keep_alive_s: options.keep_alive_interval_s,
-
- atomic_disconnect: AtomicBool::new(false),
disconnect: false,
};
(task, packet_id_channel)
}
-
- pub fn sync_handle(
- &self,
- handler: &mut H,
- ) -> Result<(), MqttError> {
-
- match self.network_receiver.try_recv() {
- Ok(event) => {
- handler.handle(&event);
- },
- Err(err) => {
- if err.is_closed(){
- return Err(MqttError::IncomingNetworkChannelClosed)
+ // pub fn sync_handle(&self, handler: &mut H) -> Result<(), MqttError> {
+ // match self.network_receiver.try_recv() {
+ // Ok(event) => {
+ // handler.handle(&event);
+ // }
+ // Err(err) => {
+ // if err.is_closed() {
+ // return Err(MqttError::IncomingNetworkChannelClosed);
+ // }
+ // }
+ // }
+ // match self.client_to_handler_r.try_recv() {
+ // Ok(_) => {}
+ // Err(err) => {
+ // if err.is_closed() {
+ // return Err(MqttError::IncomingNetworkChannelClosed);
+ // }
+ // }
+ // }
+ // Ok(())
+ // }
+
+ pub async fn handle(&mut self, handler: &H) -> Result
+ where
+ H: AsyncEventHandler, {
+ futures::select! {
+ incoming = self.network_receiver.recv().fuse() => {
+ match incoming {
+ Ok(event) => {
+ // debug!("Event Handler, handling incoming packet: {}", event);
+ handler.handle(&event).await;
+ self.handle_incoming_packet(event).await?;
+ }
+ Err(_) => return Err(MqttError::IncomingNetworkChannelClosed),
}
- },
- }
- match self.client_to_handler_r.try_recv() {
- Ok(_) => {
-
- },
- Err(err) => {
- if err.is_closed(){
- return Err(MqttError::IncomingNetworkChannelClosed)
+ if self.disconnect {
+ self.disconnect = true;
+ return Ok(HandlerStatus::IncomingDisconnect);
}
},
+ outgoing = self.client_to_handler_r.recv().fuse() => {
+ match outgoing {
+ Ok(event) => {
+ // debug!("Event Handler, handling outgoing packet: {}", event);
+ self.handle_outgoing_packet(event).await?
+ }
+ Err(_) => return Err(MqttError::ClientChannelClosed),
+ }
+ if self.disconnect {
+ self.disconnect = true;
+ return Ok(HandlerStatus::OutgoingDisconnect);
+ }
+ }
}
- Ok(())
- // let keepalive = async {
- // let initial_keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0);
- // let mut keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0);
- // loop {
- // warn!("Awaiting PING sleep");
- // #[cfg(feature = "tokio")]
- // tokio::time::sleep(keep_alive_duration).await;
-
- // if (!self.waiting_for_pingresp.load(Ordering::Acquire))
- // && self.last_network_action.lock().await.elapsed()
- // >= initial_keep_alive_duration
- // {
- // self.network_sender.send(Packet::PingReq).await?;
- // self.waiting_for_pingresp.store(true, Ordering::Release);
- // keep_alive_duration = initial_keep_alive_duration;
- // } else {
- // keep_alive_duration = initial_keep_alive_duration
- // - self.last_network_action.lock().await.elapsed();
- // }
- // if self.disconnect.load(Ordering::Acquire) {
- // return Ok::<(), MqttError>(());
- // }
- // }
- // };
+ Ok(HandlerStatus::Active)
}
-
- pub async fn handle(
- &mut self,
- handler: &mut H,
- ) -> Result<(), MqttError> {
+ pub async fn handle_mut(&mut self, handler: &mut H) -> Result
+ where
+ H: AsyncEventHandlerMut, {
futures::select! {
incoming = self.network_receiver.recv().fuse() => {
match incoming {
Ok(event) => {
// debug!("Event Handler, handling incoming packet: {}", event);
handler.handle(&event).await;
- self.handle_incoming_packet(event).await?
+ self.handle_incoming_packet(event).await?;
}
Err(_) => return Err(MqttError::IncomingNetworkChannelClosed),
}
- if self.atomic_disconnect.load(Ordering::Acquire) {
- self.atomic_disconnect.store(false, Ordering::Release);
- return Ok::<(), MqttError>(());
+ if self.disconnect {
+ self.disconnect = true;
+ return Ok(HandlerStatus::IncomingDisconnect);
}
},
outgoing = self.client_to_handler_r.recv().fuse() => {
@@ -163,43 +146,16 @@ impl EventHandlerTask {
}
Err(_) => return Err(MqttError::ClientChannelClosed),
}
- if self.atomic_disconnect.load(Ordering::Acquire) {
- self.atomic_disconnect.store(false, Ordering::Release);
- return Ok::<(), MqttError>(());
+ if self.disconnect {
+ self.disconnect = true;
+ return Ok(HandlerStatus::OutgoingDisconnect);
}
}
}
- Ok(())
- // let keepalive = async {
- // let initial_keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0);
- // let mut keep_alive_duration = std::time::Duration::new(self.keep_alive_s, 0);
- // loop {
- // warn!("Awaiting PING sleep");
- // #[cfg(feature = "tokio")]
- // tokio::time::sleep(keep_alive_duration).await;
-
- // if (!self.waiting_for_pingresp.load(Ordering::Acquire))
- // && self.last_network_action.lock().await.elapsed()
- // >= initial_keep_alive_duration
- // {
- // self.network_sender.send(Packet::PingReq).await?;
- // self.waiting_for_pingresp.store(true, Ordering::Release);
- // keep_alive_duration = initial_keep_alive_duration;
- // } else {
- // keep_alive_duration = initial_keep_alive_duration
- // - self.last_network_action.lock().await.elapsed();
- // }
- // if self.disconnect.load(Ordering::Acquire) {
- // return Ok::<(), MqttError>(());
- // }
- // }
- // };
+ Ok(HandlerStatus::Active)
}
- async fn handle_incoming_packet(
- &mut self,
- packet: Packet,
- ) -> Result<(), MqttError> {
+ async fn handle_incoming_packet(&mut self, packet: Packet) -> Result<(), MqttError> {
match packet {
Packet::Publish(publish) => self.handle_incoming_publish(&publish).await?,
Packet::PubAck(puback) => self.handle_incoming_puback(&puback).await?,
@@ -208,11 +164,11 @@ impl EventHandlerTask {
Packet::PubComp(pubcomp) => self.handle_incoming_pubcomp(&pubcomp).await?,
Packet::SubAck(suback) => self.handle_incoming_suback(suback).await?,
Packet::UnsubAck(unsuback) => self.handle_incoming_unsuback(unsuback).await?,
- Packet::PingResp => self.handle_incoming_pingresp().await,
+ Packet::PingResp => (),
Packet::ConnAck(_) => (),
Packet::Disconnect(_) => {
- self.atomic_disconnect.store(true, Ordering::Release);
- },
+ self.disconnect = true;
+ }
a => unreachable!("Should not receive {}", a),
};
Ok(())
@@ -252,7 +208,12 @@ impl EventHandlerTask {
}
async fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result<(), MqttError> {
- if let Some(_) = self.state.outgoing_pub.remove(&puback.packet_identifier) {
+ if self
+ .state
+ .outgoing_pub
+ .remove(&puback.packet_identifier)
+ .is_some()
+ {
#[cfg(test)]
debug!(
"Publish {:?} has been acknowledged",
@@ -263,7 +224,8 @@ impl EventHandlerTask {
.mark_available(puback.packet_identifier)
.await?;
Ok(())
- } else {
+ }
+ else {
error!(
"Publish {:?} was not found, while receiving a PubAck for it",
puback.packet_identifier,
@@ -284,10 +246,7 @@ impl EventHandlerTask {
self.network_sender.send(Packet::PubRel(pubrel)).await?;
#[cfg(test)]
- debug!(
- "Publish {:?} has been PubReced",
- pubrec.packet_identifier
- );
+ debug!("Publish {:?} has been PubReced", pubrec.packet_identifier);
Ok(())
}
_ => Ok(()),
@@ -318,7 +277,8 @@ impl EventHandlerTask {
.mark_available(pubcomp.packet_identifier)
.await?;
Ok(())
- } else {
+ }
+ else {
error!(
"PubRel {} was not found, while receiving a PubComp for it",
pubcomp.packet_identifier,
@@ -330,18 +290,20 @@ impl EventHandlerTask {
}
}
- async fn handle_incoming_pingresp(&mut self) {
- self.atomic_waiting_for_pingresp.store(false, Ordering::Release);
- }
-
async fn handle_incoming_suback(&mut self, suback: SubAck) -> Result<(), MqttError> {
- if self.state.outgoing_sub.remove(&suback.packet_identifier).is_some() {
+ if self
+ .state
+ .outgoing_sub
+ .remove(&suback.packet_identifier)
+ .is_some()
+ {
self.state
.apkid
.mark_available(suback.packet_identifier)
.await?;
Ok(())
- } else {
+ }
+ else {
error!(
"Sub {} was not found, while receiving a SubAck for it",
suback.packet_identifier,
@@ -354,13 +316,19 @@ impl EventHandlerTask {
}
async fn handle_incoming_unsuback(&mut self, unsuback: UnsubAck) -> Result<(), MqttError> {
- if self.state.outgoing_unsub.remove(&unsuback.packet_identifier).is_some() {
+ if self
+ .state
+ .outgoing_unsub
+ .remove(&unsuback.packet_identifier)
+ .is_some()
+ {
self.state
.apkid
.mark_available(unsuback.packet_identifier)
.await?;
Ok(())
- } else {
+ }
+ else {
error!(
"Unsub {} was not found, while receiving a unsuback for it",
unsuback.packet_identifier,
@@ -391,8 +359,10 @@ impl EventHandlerTask {
self.network_sender
.send(Packet::Publish(publish.clone()))
.await?;
- if let Some(pub_collision) =
- self.state.outgoing_pub.insert(publish.packet_identifier.unwrap(), publish)
+ if let Some(pub_collision) = self
+ .state
+ .outgoing_pub
+ .insert(publish.packet_identifier.unwrap(), publish)
{
error!(
"Encountered a colliding packet ID ({:?}) in a publish QoS 1 packet",
@@ -404,8 +374,10 @@ impl EventHandlerTask {
self.network_sender
.send(Packet::Publish(publish.clone()))
.await?;
- if let Some(pub_collision) =
- self.state.outgoing_pub.insert(publish.packet_identifier.unwrap(), publish)
+ if let Some(pub_collision) = self
+ .state
+ .outgoing_pub
+ .insert(publish.packet_identifier.unwrap(), publish)
{
error!(
"Encountered a colliding packet ID ({:?}) in a publish QoS 2 packet",
@@ -418,7 +390,9 @@ impl EventHandlerTask {
}
async fn handle_outgoing_subscribe(&mut self, sub: Subscribe) -> Result<(), MqttError> {
- if self.state.outgoing_sub
+ if self
+ .state
+ .outgoing_sub
.insert(sub.packet_identifier, sub.clone())
.is_some()
{
@@ -426,14 +400,17 @@ impl EventHandlerTask {
"Encountered a colliding packet ID ({}) in a subscribe packet",
sub.packet_identifier,
)
- } else {
+ }
+ else {
self.network_sender.send(Packet::Subscribe(sub)).await?;
}
Ok(())
}
async fn handle_outgoing_unsubscribe(&mut self, unsub: Unsubscribe) -> Result<(), MqttError> {
- if self.state.outgoing_unsub
+ if self
+ .state
+ .outgoing_unsub
.insert(unsub.packet_identifier, unsub.clone())
.is_some()
{
@@ -441,13 +418,17 @@ impl EventHandlerTask {
"Encountered a colliding packet ID ({}) in a unsubscribe packet",
unsub.packet_identifier,
)
- } else {
+ }
+ else {
self.network_sender.send(Packet::Unsubscribe(unsub)).await?;
}
Ok(())
}
- async fn handle_outgoing_disconnect(&mut self, disconnect: Disconnect) -> Result<(), MqttError> {
+ async fn handle_outgoing_disconnect(
+ &mut self,
+ disconnect: Disconnect,
+ ) -> Result<(), MqttError> {
// self.atomic_disconnect.store(true, Ordering::Release);
self.disconnect = true;
self.network_sender
@@ -457,372 +438,356 @@ impl EventHandlerTask {
}
}
-pub trait AsyncEventHandler: Sized + Sync + 'static {
- fn handle<'a>(&'a mut self, event: &'a Packet) -> impl Future + Send + 'a;
-}
-
-pub trait EventHandler: Sized {
- fn handle<'a>(&'a mut self, event: &'a Packet);
-}
-
#[cfg(test)]
mod handler_tests {
- use std::{sync::Arc, time::Duration};
-
- use async_channel::{Receiver, Sender};
- use async_mutex::Mutex;
-
- use crate::{
- connect_options::ConnectOptions,
- event_handler::{AsyncEventHandler, EventHandlerTask},
- packets::{
- {Packet, PacketType},
- {PubComp, PubCompProperties},
- {PubRec, PubRecProperties},
- {PubRel, PubRelProperties},
- reason_codes::{
- PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode,
- },
- {SubAck, SubAckProperties},
- QoS,
- },
- tests::resources::test_packets::{
- create_disconnect_packet, create_puback_packet, create_publish_packet,
- create_subscribe_packet,
- },
- };
-
- pub struct Nop {}
- impl AsyncEventHandler for Nop {
- fn handle<'a>(
- &'a mut self,
- _event: &'a Packet,
- ) -> impl core::future::Future + Send + 'a {
- async move { () }
- }
- }
-
- fn handler() -> (
- EventHandlerTask,
- Receiver,
- Sender,
- Sender,
- ) {
- let opt = ConnectOptions::new("127.0.0.1".to_string(), 1883, "test123123".to_string());
-
- let (to_network_s, to_network_r) = async_channel::bounded(100);
- let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100);
- let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100);
-
- let (handler, _apkid) = EventHandlerTask::new(
- &opt,
- network_to_handler_r,
- to_network_s,
- client_to_handler_r,
- Arc::new(Mutex::new(
- std::time::Instant::now() + Duration::new(600, 0),
- )),
- );
- (
- handler,
- to_network_r,
- network_to_handler_s,
- client_to_handler_s,
- )
- }
-
- #[tokio::test(flavor = "multi_thread")]
- async fn outgoing_publish_qos_0() {
- let mut nop = Nop {};
-
- let (mut handler, to_network_r, _network_to_handler_s, client_to_handler_s) = handler();
-
- let handler_task = tokio::task::spawn(async move {
- let _ = loop{
- match handler.handle(&mut nop).await{
- Ok(_) => (),
- Err(_) => break,
- }
- };
- return handler;
- });
- let pub_packet = create_publish_packet(QoS::AtMostOnce, false, false, None);
-
- client_to_handler_s.send(pub_packet.clone()).await.unwrap();
-
- let packet = to_network_r.recv().await.unwrap();
+ // use std::{sync::Arc, time::Duration};
+
+ // use async_channel::{Receiver, Sender};
+ // use async_mutex::Mutex;
+
+ // use crate::{
+ // connect_options::ConnectOptions,
+ // event_handler::{EventHandlerTask},
+ // packets::{
+ // reason_codes::{
+ // PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode,
+ // },
+ // QoS, {Packet, PacketType}, {PubComp, PubCompProperties}, {PubRec, PubRecProperties},
+ // {PubRel, PubRelProperties}, {SubAck, SubAckProperties},
+ // },
+ // tests::resources::test_packets::{
+ // create_disconnect_packet, create_puback_packet, create_publish_packet,
+ // create_subscribe_packet,
+ // },
+ // };
+
+ // pub struct Nop {}
+ // // impl AsyncEventHandler for Nop {
+ // // fn handle<'a>(
+ // // &'a mut self,
+ // // _event: &'a Packet,
+ // // ) -> impl core::future::Future + Send + 'a {
+ // // async move { () }
+ // // }
+ // // }
+
+ // fn handler() -> (
+ // EventHandlerTask,
+ // Receiver,
+ // Sender,
+ // Sender,
+ // ) {
+ // let opt = ConnectOptions::new("test123123".to_string());
+
+ // let (to_network_s, to_network_r) = async_channel::bounded(100);
+ // let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100);
+ // let (client_to_handler_s, client_to_handler_r) = async_channel::bounded(100);
+
+ // let (handler, _apkid) = EventHandlerTask::new(
+ // &opt,
+ // network_to_handler_r,
+ // to_network_s,
+ // client_to_handler_r,
+ // );
+ // (
+ // handler,
+ // to_network_r,
+ // network_to_handler_s,
+ // client_to_handler_s,
+ // )
+ // }
+
+ // #[tokio::test(flavor = "multi_thread")]
+ // async fn outgoing_publish_qos_0() {
+ // let mut nop = Nop {};
+
+ // let (mut handler, to_network_r, _network_to_handler_s, client_to_handler_s) = handler();
+
+ // let handler_task = tokio::task::spawn(async move {
+ // let _ = loop {
+ // match handler.handle(&mut nop).await {
+ // Ok(_) => (),
+ // Err(_) => break,
+ // }
+ // };
+ // return handler;
+ // });
+ // let pub_packet = create_publish_packet(QoS::AtMostOnce, false, false, None);
+
+ // client_to_handler_s.send(pub_packet.clone()).await.unwrap();
+
+ // let packet = to_network_r.recv().await.unwrap();
+
+ // assert_eq!(packet, pub_packet);
+
+ // // If we drop the client to handler channel the handler will stop executing and we can inspect its internals.
+ // drop(client_to_handler_s);
+
+ // let handler = handler_task.await.unwrap();
+
+ // assert!(handler.state.incoming_pub.is_empty());
+ // assert!(handler.state.outgoing_pub.is_empty());
+ // assert!(handler.state.outgoing_rel.is_empty());
+ // assert!(handler.state.outgoing_sub.is_empty());
+ // }
+
+ // #[tokio::test(flavor = "multi_thread")]
+ // async fn outgoing_publish_qos_1() {
+ // pub struct TestPubQoS1 {
+ // stage: StagePubQoS1,
+ // }
+ // pub enum StagePubQoS1 {
+ // PubAck,
+ // Done,
+ // }
+ // impl TestPubQoS1 {
+ // fn new() -> Self {
+ // TestPubQoS1 {
+ // stage: StagePubQoS1::PubAck,
+ // }
+ // }
+ // }
+ // // impl AsyncEventHandler for TestPubQoS1 {
+ // // fn handle<'a>(
+ // // &'a mut self,
+ // // event: &'a Packet,
+ // // ) -> impl core::future::Future + Send + 'a {
+ // // async move {
+ // // match self.stage {
+ // // StagePubQoS1::PubAck => {
+ // // assert_eq!(event.packet_type(), PacketType::PubAck);
+ // // self.stage = StagePubQoS1::Done;
+ // // }
+ // // StagePubQoS1::Done => (),
+ // // }
+ // // }
+ // // }
+ // // }
+
+ // let mut nop = TestPubQoS1::new();
+
+ // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
+
+ // let handler_task = tokio::task::spawn(async move {
+ // // Ignore the error that this will return
+ // let _ = loop {
+ // match handler.handle(&mut nop).await {
+ // Ok(_) => (),
+ // Err(_) => break,
+ // }
+ // };
+ // return handler;
+ // });
+ // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
+
+ // client_to_handler_s.send(pub_packet.clone()).await.unwrap();
+
+ // let publish = to_network_r.recv().await.unwrap();
+
+ // assert_eq!(pub_packet, publish);
+
+ // let puback = create_puback_packet(1);
+
+ // network_to_handler_s.send(puback).await.unwrap();
+
+ // tokio::time::sleep(Duration::new(5, 0)).await;
+
+ // // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals.
+ // drop(client_to_handler_s);
+ // drop(network_to_handler_s);
+
+ // let handler = handler_task.await.unwrap();
+
+ // assert!(handler.state.incoming_pub.is_empty());
+ // assert!(handler.state.outgoing_pub.is_empty());
+ // assert!(handler.state.outgoing_rel.is_empty());
+ // assert!(handler.state.outgoing_sub.is_empty());
+ // }
+
+ // #[tokio::test(flavor = "multi_thread")]
+ // async fn incoming_publish_qos_1() {
+ // let mut nop = Nop {};
+
+ // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
+
+ // let handler_task = tokio::task::spawn(async move {
+ // // Ignore the error that this will return
+ // let _ = loop {
+ // match handler.handle(&mut nop).await {
+ // Ok(_) => (),
+ // Err(_) => break,
+ // }
+ // };
+ // return handler;
+ // });
+ // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
+
+ // network_to_handler_s.send(pub_packet.clone()).await.unwrap();
+
+ // let puback = to_network_r.recv().await.unwrap();
+
+ // assert_eq!(PacketType::PubAck, puback.packet_type());
+
+ // let expected_puback = create_puback_packet(1);
+
+ // assert_eq!(expected_puback, puback);
+
+ // // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals.
+ // drop(client_to_handler_s);
+ // drop(network_to_handler_s);
+
+ // let handler = handler_task.await.unwrap();
+
+ // assert!(handler.state.incoming_pub.is_empty());
+ // assert!(handler.state.outgoing_pub.is_empty());
+ // assert!(handler.state.outgoing_rel.is_empty());
+ // assert!(handler.state.outgoing_sub.is_empty());
+ // }
+
+ // #[tokio::test(flavor = "multi_thread")]
+ // async fn outgoing_publish_qos_2() {
+ // pub struct TestPubQoS2 {
+ // stage: StagePubQoS2,
+ // client_to_handler_s: Sender,
+ // }
+ // pub enum StagePubQoS2 {
+ // PubRec,
+ // PubComp,
+ // Done,
+ // }
+ // impl TestPubQoS2 {
+ // fn new(client_to_handler_s: Sender) -> Self {
+ // TestPubQoS2 {
+ // stage: StagePubQoS2::PubRec,
+ // client_to_handler_s,
+ // }
+ // }
+ // }
+ // // impl AsyncEventHandler for TestPubQoS2 {
+ // // fn handle<'a>(
+ // // &'a mut self,
+ // // event: &'a Packet,
+ // // ) -> impl core::future::Future + Send + 'a {
+ // // async move {
+ // // match self.stage {
+ // // StagePubQoS2::PubRec => {
+ // // assert_eq!(event.packet_type(), PacketType::PubRec);
+ // // self.stage = StagePubQoS2::PubComp;
+ // // }
+ // // StagePubQoS2::PubComp => {
+ // // assert_eq!(event.packet_type(), PacketType::PubComp);
+ // // self.stage = StagePubQoS2::Done;
+ // // self.client_to_handler_s
+ // // .send(create_disconnect_packet())
+ // // .await
+ // // .unwrap();
+ // // }
+ // // StagePubQoS2::Done => (),
+ // // }
+ // // }
+ // // }
+ // // }
+
+ // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
+
+ // let mut nop = TestPubQoS2::new(client_to_handler_s.clone());
+
+ // let handler_task = tokio::task::spawn(async move {
+ // let _ = loop {
+ // match handler.handle(&mut nop).await {
+ // Ok(_) => (),
+ // Err(_) => break,
+ // }
+ // };
+ // return handler;
+ // });
+ // let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
+
+ // client_to_handler_s.send(pub_packet.clone()).await.unwrap();
+
+ // let publish = to_network_r.recv().await.unwrap();
+
+ // assert_eq!(pub_packet, publish);
+
+ // let pubrec = Packet::PubRec(PubRec {
+ // packet_identifier: 1,
+ // reason_code: PubRecReasonCode::Success,
+ // properties: PubRecProperties::default(),
+ // });
+
+ // network_to_handler_s.send(pubrec).await.unwrap();
+
+ // let packet = to_network_r.recv().await.unwrap();
+
+ // let expected_pubrel = Packet::PubRel(PubRel {
+ // packet_identifier: 1,
+ // reason_code: PubRelReasonCode::Success,
+ // properties: PubRelProperties::default(),
+ // });
+
+ // assert_eq!(expected_pubrel, packet);
+
+ // let pubcomp = Packet::PubComp(PubComp {
+ // packet_identifier: 1,
+ // reason_code: PubCompReasonCode::Success,
+ // properties: PubCompProperties::default(),
+ // });
+
+ // network_to_handler_s.send(pubcomp).await.unwrap();
+
+ // drop(client_to_handler_s);
+ // drop(network_to_handler_s);
+
+ // let handler = handler_task.await.unwrap();
+
+ // assert!(handler.state.incoming_pub.is_empty());
+ // assert!(handler.state.outgoing_pub.is_empty());
+ // assert!(handler.state.outgoing_rel.is_empty());
+ // assert!(handler.state.outgoing_sub.is_empty());
+ // }
+
+ // #[tokio::test(flavor = "multi_thread")]
+ // async fn outgoing_subscribe() {
+ // let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
+
+ // let mut nop = Nop {};
+
+ // let handler_task = tokio::task::spawn(async move {
+ // // Ignore the error that this will return
+ // let _ = loop {
+ // match handler.handle(&mut nop).await {
+ // Ok(_) => (),
+ // Err(_) => break,
+ // }
+ // };
+ // return handler;
+ // });
+
+ // let sub_packet = create_subscribe_packet(1);
+
+ // client_to_handler_s.send(sub_packet.clone()).await.unwrap();
+
+ // let sub_result = to_network_r.recv().await.unwrap();
+
+ // assert_eq!(sub_packet, sub_result);
+
+ // let suback = Packet::SubAck(SubAck {
+ // packet_identifier: 1,
+ // reason_codes: vec![SubAckReasonCode::GrantedQoS0],
+ // properties: SubAckProperties::default(),
+ // });
- assert_eq!(packet, pub_packet);
+ // network_to_handler_s.send(suback).await.unwrap();
- // If we drop the client to handler channel the handler will stop executing and we can inspect its internals.
- drop(client_to_handler_s);
+ // tokio::time::sleep(Duration::new(2, 0)).await;
- let handler = handler_task.await.unwrap();
+ // drop(client_to_handler_s);
+ // drop(network_to_handler_s);
- assert!(handler.state.incoming_pub.is_empty());
- assert!(handler.state.outgoing_pub.is_empty());
- assert!(handler.state.outgoing_rel.is_empty());
- assert!(handler.state.outgoing_sub.is_empty());
- }
-
- #[tokio::test(flavor = "multi_thread")]
- async fn outgoing_publish_qos_1() {
- pub struct TestPubQoS1 {
- stage: StagePubQoS1,
- }
- pub enum StagePubQoS1 {
- PubAck,
- Done,
- }
- impl TestPubQoS1 {
- fn new() -> Self {
- TestPubQoS1 {
- stage: StagePubQoS1::PubAck,
- }
- }
- }
- impl AsyncEventHandler for TestPubQoS1 {
- fn handle<'a>(
- &'a mut self,
- event: &'a Packet,
- ) -> impl core::future::Future + Send + 'a {
- async move {
- match self.stage {
- StagePubQoS1::PubAck => {
- assert_eq!(event.packet_type(), PacketType::PubAck);
- self.stage = StagePubQoS1::Done;
- }
- StagePubQoS1::Done => (),
- }
- }
- }
- }
-
- let mut nop = TestPubQoS1::new();
-
- let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
-
- let handler_task = tokio::task::spawn(async move {
- // Ignore the error that this will return
- let _ = loop{
- match handler.handle(&mut nop).await{
- Ok(_) => (),
- Err(_) => break,
- }
- };
- return handler;
- });
- let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
-
- client_to_handler_s.send(pub_packet.clone()).await.unwrap();
-
- let publish = to_network_r.recv().await.unwrap();
-
- assert_eq!(pub_packet, publish);
-
- let puback = create_puback_packet(1);
-
- network_to_handler_s.send(puback).await.unwrap();
-
- tokio::time::sleep(Duration::new(5, 0)).await;
-
- // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals.
- drop(client_to_handler_s);
- drop(network_to_handler_s);
-
- let handler = handler_task.await.unwrap();
-
- assert!(handler.state.incoming_pub.is_empty());
- assert!(handler.state.outgoing_pub.is_empty());
- assert!(handler.state.outgoing_rel.is_empty());
- assert!(handler.state.outgoing_sub.is_empty());
- }
-
- #[tokio::test(flavor = "multi_thread")]
- async fn incoming_publish_qos_1(){
-
- let mut nop = Nop{};
-
- let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
-
- let handler_task = tokio::task::spawn(async move {
- // Ignore the error that this will return
- let _ = loop{
- match handler.handle(&mut nop).await{
- Ok(_) => (),
- Err(_) => break,
- }
- };
- return handler;
- });
- let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
-
- network_to_handler_s.send(pub_packet.clone()).await.unwrap();
-
- let puback = to_network_r.recv().await.unwrap();
-
- assert_eq!(PacketType::PubAck, puback.packet_type());
-
- let expected_puback = create_puback_packet(1);
-
- assert_eq!(expected_puback, puback);
-
- // If we drop the client_to_handler channel the handler will stop executing and we can inspect its internals.
- drop(client_to_handler_s);
- drop(network_to_handler_s);
-
- let handler = handler_task.await.unwrap();
-
- assert!(handler.state.incoming_pub.is_empty());
- assert!(handler.state.outgoing_pub.is_empty());
- assert!(handler.state.outgoing_rel.is_empty());
- assert!(handler.state.outgoing_sub.is_empty());
- }
-
- #[tokio::test(flavor = "multi_thread")]
- async fn outgoing_publish_qos_2() {
- pub struct TestPubQoS2 {
- stage: StagePubQoS2,
- client_to_handler_s: Sender,
- }
- pub enum StagePubQoS2 {
- PubRec,
- PubComp,
- Done,
- }
- impl TestPubQoS2 {
- fn new(client_to_handler_s: Sender) -> Self {
- TestPubQoS2 {
- stage: StagePubQoS2::PubRec,
- client_to_handler_s,
- }
- }
- }
- impl AsyncEventHandler for TestPubQoS2 {
- fn handle<'a>(
- &'a mut self,
- event: &'a Packet,
- ) -> impl core::future::Future + Send + 'a {
- async move {
- match self.stage {
- StagePubQoS2::PubRec => {
- assert_eq!(event.packet_type(), PacketType::PubRec);
- self.stage = StagePubQoS2::PubComp;
- }
- StagePubQoS2::PubComp => {
- assert_eq!(event.packet_type(), PacketType::PubComp);
- self.stage = StagePubQoS2::Done;
- self.client_to_handler_s
- .send(create_disconnect_packet())
- .await
- .unwrap();
- }
- StagePubQoS2::Done => (),
- }
- }
- }
- }
-
- let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
-
- let mut nop = TestPubQoS2::new(client_to_handler_s.clone());
-
- let handler_task = tokio::task::spawn(async move {
- let _ = loop{
- match handler.handle(&mut nop).await{
- Ok(_) => (),
- Err(_) => break,
- }
- };
- return handler;
- });
- let pub_packet = create_publish_packet(QoS::AtLeastOnce, false, false, Some(1));
-
- client_to_handler_s.send(pub_packet.clone()).await.unwrap();
-
- let publish = to_network_r.recv().await.unwrap();
-
- assert_eq!(pub_packet, publish);
-
- let pubrec = Packet::PubRec(PubRec {
- packet_identifier: 1,
- reason_code: PubRecReasonCode::Success,
- properties: PubRecProperties::default(),
- });
-
- network_to_handler_s.send(pubrec).await.unwrap();
-
- let packet = to_network_r.recv().await.unwrap();
-
- let expected_pubrel = Packet::PubRel(PubRel {
- packet_identifier: 1,
- reason_code: PubRelReasonCode::Success,
- properties: PubRelProperties::default(),
- });
-
- assert_eq!(expected_pubrel, packet);
-
- let pubcomp = Packet::PubComp(PubComp {
- packet_identifier: 1,
- reason_code: PubCompReasonCode::Success,
- properties: PubCompProperties::default(),
- });
-
- network_to_handler_s.send(pubcomp).await.unwrap();
-
- drop(client_to_handler_s);
- drop(network_to_handler_s);
-
- let handler = handler_task.await.unwrap();
-
- assert!(handler.state.incoming_pub.is_empty());
- assert!(handler.state.outgoing_pub.is_empty());
- assert!(handler.state.outgoing_rel.is_empty());
- assert!(handler.state.outgoing_sub.is_empty());
- }
-
- #[tokio::test(flavor = "multi_thread")]
- async fn outgoing_subscribe() {
- let (mut handler, to_network_r, network_to_handler_s, client_to_handler_s) = handler();
-
- let mut nop = Nop {};
-
- let handler_task = tokio::task::spawn(async move {
- // Ignore the error that this will return
- let _ = loop{
- match handler.handle(&mut nop).await{
- Ok(_) => (),
- Err(_) => break,
- }
- };
- return handler;
- });
-
- let sub_packet = create_subscribe_packet(1);
-
- client_to_handler_s.send(sub_packet.clone()).await.unwrap();
-
- let sub_result = to_network_r.recv().await.unwrap();
-
- assert_eq!(sub_packet, sub_result);
-
- let suback = Packet::SubAck(SubAck {
- packet_identifier: 1,
- reason_codes: vec![SubAckReasonCode::GrantedQoS0],
- properties: SubAckProperties::default(),
- });
-
- network_to_handler_s.send(suback).await.unwrap();
-
- tokio::time::sleep(Duration::new(2, 0)).await;
-
- drop(client_to_handler_s);
- drop(network_to_handler_s);
-
- let handler = handler_task.await.unwrap();
- assert!(handler.state.incoming_pub.is_empty());
- assert!(handler.state.outgoing_pub.is_empty());
- assert!(handler.state.outgoing_rel.is_empty());
- assert!(handler.state.outgoing_sub.is_empty());
- }
+ // let handler = handler_task.await.unwrap();
+ // assert!(handler.state.incoming_pub.is_empty());
+ // assert!(handler.state.outgoing_pub.is_empty());
+ // assert!(handler.state.outgoing_rel.is_empty());
+ // assert!(handler.state.outgoing_sub.is_empty());
+ // }
}
diff --git a/src/lib.rs b/src/lib.rs
index 917f47f..ba73a9e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,23 +1,27 @@
//! A pure rust MQTT client which strives to be as efficient as possible.
//! This crate strives to provide an ergonomic API and design that fits Rust.
-//!
+//!
//! There are three parts to the design of the MQTT client. The network, the event handler and the client.
-//!
+//!
//! The network - which simply reads and forms packets from the network.
//! The event handler - which makes sure that the MQTT protocol is followed.
//! By providing a custom handler during the internal handling, messages are handled before they are acked.
//! The client - which is used to send messages from different places.
-//!
+//!
//! To Do:
//! - Rebroadcast unacked packets.
-//! - Keep alive sending of PingReq and PingResp.
//! - Enforce size of outbound messages (e.g. Publish)
-//!
+//! - Sync API
+//!
//! A few questions still remain:
//! - This crate uses async channels to perform communication across its parts. Is there a better approach?
-//! These channels do allow the user to decouple this crate and the network in the future but comes at a cost of more copies
-//! - This crate provides network implementation which hinder sync and async agnosticism.
-//!
+//! These channels do allow the user to decouple the network, handlers, and clients very easily.
+//! - This crate provides network implementation which hinder syn, counter: 0c and async agnosticism.
+//! Would a true sansio implementation be better?
+//! At first this crate used custom async traits which are not stable (async_fn_in_trait).
+//! The current version allows the user to provide the appropriate stream.
+//! This also nicely relives us from having to deal with TLS configuration.
+//!
//! For the future it could be nice to be sync, async and runtime agnostic.
//! This can be achieved by decoupling the MQTT internals from the network communication.
//! The user could provide the received packets while this crate returns the response packets.
@@ -27,7 +31,7 @@
//!
//! Tokio example:
//! ----------------------------
-//! ```no_run
+//! ```ignore
//! let config = RustlsConfig::Simple {
//! ca: EMQX_CERT.to_vec(),
//! alpn: None,
@@ -35,85 +39,101 @@
//! };
//! let opt = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "test123123".to_string());
//! let (mqtt_network, handler, client) = create_tokio_rustls(opt, config);
-//!
+//!
//! task::spawn(async move {
//! join!(mqtt_network.run(), handler.handle(/* Custom handler */));
//! });
-//!
+//!
//! for i in 0..10 {
//! client.publish("test", QoS::AtLeastOnce, false, b"test payload").await.unwrap();
//! time::sleep(Duration::from_millis(100)).await;
//! }
//! ```
-//!
+//!
//! Smol example:
-//!
//! ```
+//! use mqrstt::{
+//! client::AsyncClient,
+//! connect_options::ConnectOptions,
+//! new_smol,
+//! packets::{self, Packet},
+//! AsyncEventHandlerMut, HandlerStatus, NetworkStatus,
+//! };
+//! use async_trait::async_trait;
+//! use bytes::Bytes;
//! pub struct PingPong {
-//! pub client: AsyncClient,
+//! pub client: AsyncClient,
//! }
-//!
-//! impl EventHandler for PingPong{
-//! fn handle<'a>(&'a mut self, event: &'a packets::Packet) -> impl std::future::Future + Send + 'a {
-//! async move{
-//! match event{
-//! Packet::Publish(p) => {
-//! if let Ok(payload) = String::from_utf8(p.payload.to_vec()){
-//! if payload.to_lowercase().contains("ping"){
-//! self.client.publish(p.qos, p.retain, p.topic.clone(), Bytes::from_static(b"pong")).await;
-//! println!("Received Ping, Send pong!");
-//! }
+//! #[async_trait]
+//! impl AsyncEventHandlerMut for PingPong {
+//! async fn handle(&mut self, event: &packets::Packet) -> () {
+//! match event {
+//! Packet::Publish(p) => {
+//! if let Ok(payload) = String::from_utf8(p.payload.to_vec()) {
+//! if payload.to_lowercase().contains("ping") {
+//! self.client
+//! .publish(
+//! p.qos,
+//! p.retain,
+//! p.topic.clone(),
+//! Bytes::from_static(b"pong"),
+//! )
+//! .await
+//! .unwrap();
+//! println!("Received Ping, Send pong!");
//! }
-//! },
-//! Packet::ConnAck(_) => {
-//! println!("Connected!");
//! }
-//! _ => (),
-//! }
+//! },
+//! Packet::ConnAck(_) => { println!("Connected!") },
+//! _ => (),
//! }
//! }
//! }
-//!
-//! fn main(){
-//! let options = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "mqrstt".to_string());
-//!
-//! let tls_config = RustlsConfig::Simple {
-//! ca: crate::tests::resources::EMQX_CERT.to_vec(),
-//! alpn: None,
-//! client_auth: None,
+//! smol::block_on(async {
+//! let options = ConnectOptions::new("mqrstt".to_string());
+//! let (mut network, mut handler, client) = new_smol(options);
+//! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883))
+//! .await
+//! .unwrap();
+//! network.connect(stream).await.unwrap();
+//! client.subscribe("mqrstt").await.unwrap();
+//! let mut pingpong = PingPong {
+//! client: client.clone(),
//! };
-//!
-//! let (mut network, handler, client) = create_smol_rustls(options, tls_config);
-//! smol::block_on(async{
-//! client.subscribe("mqrstt").await.unwrap();
-//!
-//! let mut pingpong = PingPong{ client };
-//!
-//! join!(network.run(), handler.handle(&mut pingpong));
-//! });
-//! }
+//! let (n, h, t) = futures::join!(
+//! async {
+//! loop {
+//! return match network.run().await {
+//! Ok(NetworkStatus::Active) => continue,
+//! otherwise => otherwise,
+//! };
+//! }
+//! },
+//! async {
+//! loop {
+//! return match handler.handle_mut(&mut pingpong).await {
+//! Ok(HandlerStatus::Active) => continue,
+//! otherwise => otherwise,
+//! };
+//! }
+//! },
+//! async {
+//! smol::Timer::after(std::time::Duration::from_secs(60)).await;
+//! client.disconnect().await.unwrap();
+//! }
+//! );
+//! assert!(n.is_ok());
+//! assert!(h.is_ok());
+//! });
//! ```
-//!
-
-#![feature(async_fn_in_trait)]
-#![feature(return_position_impl_trait_in_trait)]
-
-use std::{sync::Arc, time::Instant};
-
-use async_mutex::Mutex;
use client::AsyncClient;
use connect_options::ConnectOptions;
-use connections::*;
-
-// #[cfg(all(feature = "tokio", feature = "tcp"))]
-// use connections::tokio_tcp::{TcpReader, TcpWriter};
-
-use connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite};
-
use event_handler::EventHandlerTask;
-use network::MqttNetwork;
+
+use packets::Packet;
+use smol_network::SmolNetwork;
mod available_packet_ids;
pub mod client;
@@ -121,69 +141,108 @@ pub mod connect_options;
pub mod connections;
pub mod error;
pub mod event_handler;
-pub mod network;
pub mod packets;
+pub mod smol_network;
pub mod state;
mod util;
#[cfg(test)]
pub mod tests;
+pub mod tokio_network;
+
+/// [`NetworkStatus`] Represents status of the Network object.
+/// It is returned when the run handle returns from performing an operation.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum NetworkStatus {
+ Active,
+ IncomingDisconnect,
+ OutgoingDisconnect,
+ NoPingResp,
+}
-#[cfg(all(feature = "smol", feature = "smol-rustls"))]
-pub fn create_smol_rustls(
- mut options: ConnectOptions,
- tls_config: transport::RustlsConfig,
-) -> (
- MqttNetwork,
- EventHandlerTask,
- AsyncClient,
-) {
+/// [`HandlerStatus`] Represents status of the Network object.
+/// It is returned when the run handle returns from performing an operation.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum HandlerStatus {
+ Active,
+ IncomingDisconnect,
+ OutgoingDisconnect,
+}
- options.tls_config = Some(transport::TlsConfig::Rustls(tls_config));
- new(options)
+#[async_trait::async_trait]
+/// Handlers are used to deal with packets before they are further processed (acked)
+/// This guarantees that the end user has handlded the packet.
+/// Trait for async mutable access to handler.
+/// Usefull when you have a single handler
+pub trait AsyncEventHandlerMut {
+ async fn handle(&mut self, event: &Packet);
}
-#[cfg(all(feature = "tokio", feature = "tokio-rustls"))]
-pub fn create_tokio_rustls(
- mut options: ConnectOptions,
- tls_config: transport::RustlsConfig,
-) -> (
- MqttNetwork,
- EventHandlerTask,
- AsyncClient,
-) {
+#[async_trait::async_trait]
+/// Handlers are used to deal with packets before they are further processed (acked)
+/// This guarantees that the end user has handlded the packet.
+/// Trait for async immutable access to handler.
+/// Usefull when you want to run multiple handlers concurrently to increase throughput.
+pub trait AsyncEventHandler {
+ async fn handle(&self, event: &Packet);
+}
- options.tls_config = Some(transport::TlsConfig::Rustls(tls_config));
- new(options)
+/// Handlers are used to deal with packets before they are further processed (acked)
+/// This guarantees that the end user has handlded the packet.
+/// Trait for sync mutable access to handler.
+/// Usefull when you want to run multiple handlers concurrently to increase throughput.
+pub trait EventHandlerMut {
+ fn handle(&mut self, event: &Packet);
}
-#[cfg(all(feature = "tokio", feature = "tcp"))]
-pub fn create_tokio_tcp(
- options: ConnectOptions,
-) -> (
- MqttNetwork,
- EventHandlerTask,
- AsyncClient,
-) {
- new(options)
+/// Handlers are used to deal with packets before they are further processed (acked)
+/// This guarantees that the end user has handlded the packet.
+/// Trait for sync immutable access to handler.
+/// Usefull when you want to run multiple handlers concurrently to increase throughput.
+pub trait EventHandler {
+ fn handle(&self, event: &Packet);
+}
+
+// #[cfg(all(feature = "smol", feature = "tokio"))]
+// std::compile_error!("The features smol and tokio can not be enabled simultaiously.");
+
+#[cfg(feature = "smol")]
+/// Creates the needed components to run the MQTT client using a stream that implements [`smol::io::AsyncReadExt`] and [`smol::io::AsyncWriteExt`]
+pub fn new_smol(options: ConnectOptions) -> (SmolNetwork, EventHandlerTask, AsyncClient)
+where
+ S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, {
+ let receive_maximum = options.receive_maximum();
+
+ let (to_network_s, to_network_r) = async_channel::bounded(100);
+ let (network_to_handler_s, network_to_handler_r) = async_channel::bounded(100);
+ let (client_to_handler_s, client_to_handler_r) =
+ async_channel::bounded(receive_maximum as usize);
+
+ let (handler, packet_ids) = EventHandlerTask::new(
+ &options,
+ network_to_handler_r,
+ to_network_s.clone(),
+ client_to_handler_r,
+ );
+
+ let network = SmolNetwork::::new(options, network_to_handler_s, to_network_r);
+
+ let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s);
+
+ (network, handler, client)
}
-#[cfg(all(feature = "smol", feature = "tcp"))]
-pub fn create_smol_tcp(
+#[cfg(feature = "tokio")]
+/// Creates the needed components to run the MQTT client using a stream that implements [`tokio::io::AsyncReadExt`] and [`tokio::io::AsyncWriteExt`]
+pub fn new_tokio(
options: ConnectOptions,
) -> (
- MqttNetwork,
+ tokio_network::TokioNetwork,
EventHandlerTask,
AsyncClient,
-) {
- new(options)
-}
-
-pub fn new(options: ConnectOptions) -> (MqttNetwork, EventHandlerTask, AsyncClient)
+)
where
- R: AsyncMqttNetworkRead,
- W: AsyncMqttNetworkWrite,
-{
+ S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin, {
let receive_maximum = options.receive_maximum();
let (to_network_s, to_network_r) = async_channel::bounded(100);
@@ -191,103 +250,282 @@ where
let (client_to_handler_s, client_to_handler_r) =
async_channel::bounded(receive_maximum as usize);
- let last_network_action = Arc::new(Mutex::new(Instant::now()));
-
let (handler, packet_ids) = EventHandlerTask::new(
&options,
network_to_handler_r,
to_network_s.clone(),
- client_to_handler_r.clone(),
- last_network_action.clone(),
+ client_to_handler_r,
);
- let network = MqttNetwork::::new(
- options,
- network_to_handler_s,
- to_network_r,
- last_network_action,
- );
+ let network =
+ tokio_network::TokioNetwork::::new(options, network_to_handler_s, to_network_r);
let client = AsyncClient::new(packet_ids, client_to_handler_s, to_network_s);
(network, handler, client)
}
-
-#[cfg(all(test, any(feature = "smol-rustls", feature = "tokio-rustls")))]
-mod lib_test{
+#[cfg(test)]
+mod lib_test {
+ use std::time::Duration;
+
+ use crate::{
+ client::AsyncClient,
+ connect_options::ConnectOptions,
+ new_smol, new_tokio,
+ packets::{self, Packet},
+ util::tls::tests::simple_rust_tls,
+ AsyncEventHandlerMut, HandlerStatus, NetworkStatus,
+ };
+ use async_trait::async_trait;
use bytes::Bytes;
- use futures::join;
-
- use crate::{client::AsyncClient, packets::{self, Packet}};
- use crate::event_handler::AsyncEventHandler;
- use crate::create_smol_rustls;
- use crate::connections::transport::RustlsConfig;
- use crate::connect_options::ConnectOptions;
+ use rustls::ServerName;
pub struct PingPong {
pub client: AsyncClient,
}
-
- impl AsyncEventHandler for PingPong{
- fn handle<'a>(&'a mut self, event: &'a packets::Packet) -> impl std::future::Future + Send + 'a {
- async move{
- match event{
- Packet::Publish(p) => {
- if let Ok(payload) = String::from_utf8(p.payload.to_vec()){
- if payload.to_lowercase().contains("ping"){
- self.client.publish(p.qos, p.retain, p.topic.clone(), Bytes::from_static(b"pong")).await;
- println!("Received Ping, Send pong!");
- }
+
+ #[async_trait]
+ impl AsyncEventHandlerMut for PingPong {
+ async fn handle(&mut self, event: &packets::Packet) -> () {
+ match event {
+ Packet::Publish(p) => {
+ if let Ok(payload) = String::from_utf8(p.payload.to_vec()) {
+ if payload.to_lowercase().contains("ping") {
+ self.client
+ .publish(
+ p.qos,
+ p.retain,
+ p.topic.clone(),
+ Bytes::from_static(b"pong"),
+ )
+ .await
+ .unwrap();
+ println!("Received Ping, Send pong!");
}
- },
- Packet::ConnAck(_) => {
- println!("Connected!");
}
- _ => (),
}
+ _ => (),
}
}
}
-
- // #[test]
- fn test_smol(){
+ #[test]
+ fn test_smol_tcp() {
+ smol::block_on(async {
+ let options = ConnectOptions::new("SmolTcpPingPong".to_string());
- let filter = tracing_subscriber::filter::EnvFilter::new("none,mqrstt=trace");
+ let address = "broker.emqx.io";
+ let port = 1883;
- let subscriber = tracing_subscriber::FmtSubscriber::builder()
- .with_env_filter(filter)
- .with_max_level(tracing::Level::TRACE)
- .with_line_number(true)
- .finish();
+ let (mut network, mut handler, client) = new_smol(options);
- tracing::subscriber::set_global_default(subscriber)
- .expect("setting default subscriber failed");
+ let stream = smol::net::TcpStream::connect((address, port))
+ .await
+ .unwrap();
- smol::block_on(async{
- let options = ConnectOptions::new("broker.emqx.io".to_string(), 8883, "mqrstt".to_string());
-
- let tls_config = RustlsConfig::Simple {
- ca: crate::tests::resources::EMQX_CERT.to_vec(),
- alpn: None,
- client_auth: None,
+ network.connect(stream).await.unwrap();
+
+ client.subscribe("mqrstt").await.unwrap();
+
+ let mut pingpong = PingPong {
+ client: client.clone(),
};
-
- let (mut network, mut handler, client) = create_smol_rustls(options, tls_config);
-
+
+ let (n, h, _) = futures::join!(
+ async {
+ loop {
+ return match network.run().await {
+ Ok(NetworkStatus::Active) => continue,
+ otherwise => otherwise,
+ };
+ }
+ },
+ async {
+ loop {
+ return match handler.handle_mut(&mut pingpong).await {
+ Ok(HandlerStatus::Active) => continue,
+ otherwise => otherwise,
+ };
+ }
+ },
+ async {
+ smol::Timer::after(std::time::Duration::from_secs(60)).await;
+ client.disconnect().await.unwrap();
+ }
+ );
+ assert!(n.is_ok());
+ assert!(h.is_ok());
+ });
+ }
+
+ #[test]
+ fn test_smol_tls() {
+ smol::block_on(async {
+ let options = ConnectOptions::new("SmolTlsPingPong".to_string());
+
+ let address = "broker.emqx.io";
+ let port = 8883;
+
+ let (mut network, mut handler, client) = new_smol(options);
+
+ let arc_client_config =
+ simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap();
+
+ let domain = ServerName::try_from(address).unwrap();
+ let connector = async_rustls::TlsConnector::from(arc_client_config);
+
+ let stream = smol::net::TcpStream::connect((address, port))
+ .await
+ .unwrap();
+ let connection = connector.connect(domain, stream).await.unwrap();
+
+ network.connect(connection).await.unwrap();
+
client.subscribe("mqrstt").await.unwrap();
-
-
- let mut pingpong = PingPong{ client };
- join!(network.run(),
+ let mut pingpong = PingPong {
+ client: client.clone(),
+ };
+
+ let (n, h, _) = futures::join!(
async {
- loop{
- handler.handle(&mut pingpong).await.unwrap();
+ loop {
+ return match network.run().await {
+ Ok(NetworkStatus::Active) => continue,
+ otherwise => otherwise,
+ };
}
+ },
+ async {
+ loop {
+ return match handler.handle_mut(&mut pingpong).await {
+ Ok(HandlerStatus::Active) => continue,
+ otherwise => otherwise,
+ };
+ }
+ },
+ async {
+ smol::Timer::after(std::time::Duration::from_secs(60)).await;
+ client.disconnect().await.unwrap();
}
- ).0.unwrap();
+ );
+ assert!(n.is_ok());
+ assert!(h.is_ok());
});
}
-}
\ No newline at end of file
+
+ #[tokio::test]
+ async fn test_tokio_tcp() {
+ let options = ConnectOptions::new("TokioTcpPingPong".to_string());
+
+ let (mut network, mut handler, client) = new_tokio(options);
+
+ let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883))
+ .await
+ .unwrap();
+
+ network.connect(stream).await.unwrap();
+
+ client.subscribe("mqrstt").await.unwrap();
+
+ let mut pingpong = PingPong {
+ client: client.clone(),
+ };
+
+ let (n, h, _) = tokio::join!(
+ async {
+ loop {
+ return match network.run().await {
+ Ok(NetworkStatus::Active) => continue,
+ otherwise => otherwise,
+ };
+ }
+ },
+ async {
+ loop {
+ return match handler.handle_mut(&mut pingpong).await {
+ Ok(HandlerStatus::Active) => continue,
+ otherwise => otherwise,
+ };
+ }
+ },
+ async {
+ tokio::time::sleep(Duration::from_secs(60)).await;
+ client.disconnect().await.unwrap();
+ }
+ );
+ assert!(n.is_ok());
+ assert!(h.is_ok());
+
+ assert_eq!(NetworkStatus::OutgoingDisconnect, n.unwrap());
+ assert_eq!(HandlerStatus::OutgoingDisconnect, h.unwrap());
+ }
+
+ #[tokio::test]
+ async fn test_tokio_tls() {
+ let options = ConnectOptions::new("TokioTlsPingPong".to_string());
+
+ let address = "broker.emqx.io";
+ let port = 8883;
+
+ let (mut network, mut handler, client) = new_tokio(options);
+
+ let arc_client_config =
+ simple_rust_tls(crate::tests::resources::EMQX_CERT.to_vec(), None, None).unwrap();
+
+ let domain = ServerName::try_from(address).unwrap();
+ let connector = tokio_rustls::TlsConnector::from(arc_client_config);
+
+ let stream = tokio::net::TcpStream::connect((address, port))
+ .await
+ .unwrap();
+ let connection = connector.connect(domain, stream).await.unwrap();
+
+ network.connect(connection).await.unwrap();
+
+ client.subscribe("mqrstt").await.unwrap();
+
+ let mut pingpong = PingPong {
+ client: client.clone(),
+ };
+
+ let (n, h, _) = tokio::join!(
+ async {
+ loop {
+ return match network.run().await {
+ Ok(NetworkStatus::IncomingDisconnect) => {
+ Ok(NetworkStatus::IncomingDisconnect)
+ }
+ Ok(NetworkStatus::OutgoingDisconnect) => {
+ Ok(NetworkStatus::OutgoingDisconnect)
+ }
+ Ok(NetworkStatus::NoPingResp) => Ok(NetworkStatus::NoPingResp),
+ Ok(NetworkStatus::Active) => continue,
+ Err(a) => Err(a),
+ };
+ }
+ },
+ async {
+ loop {
+ return match handler.handle_mut(&mut pingpong).await {
+ Ok(HandlerStatus::IncomingDisconnect) => {
+ Ok(NetworkStatus::IncomingDisconnect)
+ }
+ Ok(HandlerStatus::OutgoingDisconnect) => {
+ Ok(NetworkStatus::OutgoingDisconnect)
+ }
+ Ok(HandlerStatus::Active) => continue,
+ Err(a) => Err(a),
+ };
+ }
+ },
+ async {
+ tokio::time::sleep(Duration::from_secs(60)).await;
+ client.disconnect().await.unwrap();
+ }
+ );
+ assert!(n.is_ok());
+ assert!(h.is_ok());
+ }
+}
diff --git a/src/network.rs b/src/network.rs
deleted file mode 100644
index 81a3510..0000000
--- a/src/network.rs
+++ /dev/null
@@ -1,121 +0,0 @@
-use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Arc;
-
-use async_channel::{Receiver, Sender};
-use async_mutex::Mutex;
-
-use futures_concurrency::future::Join;
-use std::time::Instant;
-use tracing::trace;
-
-use crate::connect_options::ConnectOptions;
-use crate::connections::{AsyncMqttNetworkRead, AsyncMqttNetworkWrite};
-use crate::error::ConnectionError;
-use crate::packets::Packet;
-use crate::util::timeout::{timeout, self};
-
-pub type Incoming = Packet;
-pub type Outgoing = Packet;
-
-pub struct MqttNetwork {
- network: Option<(R, W)>,
-
- // write_buffer: BytesMut,
- /// Options of the current mqtt connection
- options: ConnectOptions,
-
- last_network_action: Arc>,
-
- network_to_handler_s: Sender,
- // incoming_packet_receiver: Receiver,
- // outgoing_packet_sender: Sender,
- to_network_r: Receiver,
-}
-
-impl MqttNetwork
-where
- R: AsyncMqttNetworkRead,
- W: AsyncMqttNetworkWrite,
-{
- pub fn new(
- options: ConnectOptions,
- network_to_handler_s: Sender,
- to_network_r: Receiver,
- last_network_action: Arc>,
- ) -> Self {
- Self {
- network: None,
-
- options,
-
- last_network_action,
-
- network_to_handler_s,
- to_network_r,
- }
- }
-
- pub fn reset(&mut self) {
- self.network = None;
- }
-
- pub async fn run(&mut self) -> Result<(), ConnectionError> {
- if self.network.is_none() {
- trace!("Creating network");
-
- // let con = R::connect(&self.options).await;
- let con = timeout::timeout(R::connect(&self.options), self.options.connection_timeout_s).await?;
-
- let (reader, writer, connack) = con?;
-
- trace!("Succesfully created network");
- self.network = Some((reader, writer));
- self.network_to_handler_s.send(connack).await?;
- }
-
- let MqttNetwork {
- network,
- options: _,
- last_network_action,
- network_to_handler_s,
- to_network_r,
- } = self;
-
- if let Some((reader, writer)) = network {
- let disconnect = AtomicBool::new(false);
-
- let incoming = async {
- loop {
- let local_disconnect = reader.read_direct(network_to_handler_s).await?;
- *(last_network_action.lock().await) = std::time::Instant::now();
- if local_disconnect {
- disconnect.store(true, Ordering::Release);
- return Ok(());
- } else if disconnect.load(Ordering::Acquire) {
- return Ok(());
- }
- }
- };
-
- let outgoing = async {
- loop {
- let local_disconnect = writer.write(to_network_r).await?;
- *(last_network_action.lock().await) = std::time::Instant::now();
- if local_disconnect {
- disconnect.store(true, Ordering::Release);
- return Ok(());
- } else if disconnect.load(Ordering::Acquire) {
- return Ok(());
- }
- }
- };
-
- let res: (Result<(), ConnectionError>, Result<(), ConnectionError>) =
- (incoming, outgoing).join().await;
- res.0?;
- res.1
- } else {
- Ok(())
- }
- }
-}
diff --git a/src/packets/auth.rs b/src/packets/auth.rs
index 86d56ba..04f47ce 100644
--- a/src/packets/auth.rs
+++ b/src/packets/auth.rs
@@ -67,7 +67,8 @@ impl MqttRead for AuthProperties {
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::MalformedPacket);
}
diff --git a/src/packets/connack.rs b/src/packets/connack.rs
index 1724ce1..6427054 100644
--- a/src/packets/connack.rs
+++ b/src/packets/connack.rs
@@ -123,7 +123,8 @@ impl MqttRead for ConnAckProperties {
let mut properties = Self::default();
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"ConnAckProperties".to_string(),
buf.len(),
diff --git a/src/packets/connect.rs b/src/packets/connect.rs
index 5e11ff4..297fb73 100644
--- a/src/packets/connect.rs
+++ b/src/packets/connect.rs
@@ -128,12 +128,14 @@ impl VariableHeaderRead for Connect {
let username = if connect_flags.contains(ConnectFlags::USERNAME) {
Some(String::read(&mut buf)?)
- } else {
+ }
+ else {
None
};
let password = if connect_flags.contains(ConnectFlags::PASSWORD) {
Some(String::read(&mut buf)?)
- } else {
+ }
+ else {
None
};
@@ -351,7 +353,8 @@ impl MqttRead for ConnectProperties {
let mut properties = Self::default();
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"ConnectProperties".to_string(),
buf.len(),
@@ -573,7 +576,8 @@ impl MqttRead for LastWillProperties {
let mut properties = Self::default();
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"LastWillProperties".to_string(),
buf.len(),
diff --git a/src/packets/disconnect.rs b/src/packets/disconnect.rs
index 254f6fc..908cbf0 100644
--- a/src/packets/disconnect.rs
+++ b/src/packets/disconnect.rs
@@ -34,7 +34,8 @@ impl VariableHeaderRead for Disconnect {
if remaining_length == 0 {
reason_code = DisconnectReasonCode::NormalDisconnection;
properties = DisconnectProperties::default();
- } else {
+ }
+ else {
reason_code = DisconnectReasonCode::read(&mut buf)?;
properties = DisconnectProperties::read(&mut buf)?;
}
@@ -62,7 +63,8 @@ impl WireLength for Disconnect {
|| self.properties.wire_len() != 0
{
self.properties.wire_len() + 1
- } else {
+ }
+ else {
0
}
}
@@ -83,7 +85,8 @@ impl MqttRead for DisconnectProperties {
let mut properties = Self::default();
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"DisconnectProperties".to_string(),
buf.len(),
diff --git a/src/packets/mod.rs b/src/packets/mod.rs
index f12c40f..3e90c4e 100644
--- a/src/packets/mod.rs
+++ b/src/packets/mod.rs
@@ -2,14 +2,13 @@ pub mod error;
pub mod mqtt_traits;
pub mod reason_codes;
-mod publish;
mod auth;
mod connack;
mod connect;
mod disconnect;
-mod packets;
mod puback;
mod pubcomp;
+mod publish;
mod pubrec;
mod pubrel;
mod suback;
@@ -17,28 +16,26 @@ mod subscribe;
mod unsuback;
mod unsubscribe;
-
pub use auth::*;
pub use connack::*;
pub use connect::*;
pub use disconnect::*;
pub use puback::*;
pub use pubcomp::*;
+pub use publish::*;
pub use pubrec::*;
pub use pubrel::*;
pub use suback::*;
pub use subscribe::*;
pub use unsuback::*;
pub use unsubscribe::*;
-pub use publish::*;
-
-pub use packets::*;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use core::slice::Iter;
+use std::fmt::Display;
use self::error::{DeserializeError, ReadBytes, SerializeError};
-use self::mqtt_traits::{MqttRead, MqttWrite, WireLength};
+use self::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength};
/// Protocol version
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
@@ -126,11 +123,14 @@ impl TryFrom for QoS {
fn try_from(c: ConnectFlags) -> Result {
if c.contains(ConnectFlags::WILL_QOS1 | ConnectFlags::WILL_QOS2) {
Err(DeserializeError::MalformedPacket)
- } else if c.contains(ConnectFlags::WILL_QOS2) {
+ }
+ else if c.contains(ConnectFlags::WILL_QOS2) {
Ok(QoS::ExactlyOnce)
- } else if c.contains(ConnectFlags::WILL_QOS1) {
+ }
+ else if c.contains(ConnectFlags::WILL_QOS1) {
Ok(QoS::AtLeastOnce)
- } else {
+ }
+ else {
Ok(QoS::AtMostOnce)
}
}
@@ -228,7 +228,8 @@ impl MqttWrite for bool {
if *self {
buf.put_u8(1);
Ok(())
- } else {
+ }
+ else {
buf.put_u8(0);
Ok(())
}
@@ -284,7 +285,8 @@ pub fn read_fixed_header_rem_len(
if (*byte & 0b1000_0000) == 0 {
return Ok((integer, length));
}
- } else {
+ }
+ else {
return Err(ReadBytes::InsufficientBytes(1));
}
}
@@ -335,11 +337,14 @@ pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result<(),
pub fn variable_integer_len(integer: usize) -> usize {
if integer >= 2_097_152 {
4
- } else if integer >= 16_384 {
+ }
+ else if integer >= 16_384 {
3
- } else if integer >= 128 {
+ }
+ else if integer >= 128 {
2
- } else {
+ }
+ else {
1
}
}
@@ -520,3 +525,514 @@ impl PropertyType {
}
}
}
+
+// ==================== Packets ====================
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Packet {
+ Connect(Connect),
+ ConnAck(ConnAck),
+ Publish(Publish),
+ PubAck(PubAck),
+ PubRec(PubRec),
+ PubRel(PubRel),
+ PubComp(PubComp),
+ Subscribe(Subscribe),
+ SubAck(SubAck),
+ Unsubscribe(Unsubscribe),
+ UnsubAck(UnsubAck),
+ PingReq,
+ PingResp,
+ Disconnect(Disconnect),
+ Auth(Auth),
+}
+
+impl Packet {
+ pub fn packet_type(&self) -> PacketType {
+ match self {
+ Packet::Connect(_) => PacketType::Connect,
+ Packet::ConnAck(_) => PacketType::ConnAck,
+ Packet::Publish(_) => PacketType::Publish,
+ Packet::PubAck(_) => PacketType::PubAck,
+ Packet::PubRec(_) => PacketType::PubRec,
+ Packet::PubRel(_) => PacketType::PubRel,
+ Packet::PubComp(_) => PacketType::PubComp,
+ Packet::Subscribe(_) => PacketType::Subscribe,
+ Packet::SubAck(_) => PacketType::SubAck,
+ Packet::Unsubscribe(_) => PacketType::Unsubscribe,
+ Packet::UnsubAck(_) => PacketType::UnsubAck,
+ Packet::PingReq => PacketType::PingReq,
+ Packet::PingResp => PacketType::PingResp,
+ Packet::Disconnect(_) => PacketType::Disconnect,
+ Packet::Auth(_) => PacketType::Auth,
+ }
+ }
+
+ pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
+ match self {
+ Packet::Connect(p) => {
+ buf.put_u8(0b0001_0000);
+ write_variable_integer(buf, p.wire_len())?;
+
+ p.write(buf)?;
+ }
+ Packet::ConnAck(_) => {
+ unreachable!()
+ }
+ Packet::Publish(p) => {
+ let mut first_byte = 0b0011_0000u8;
+ if p.dup {
+ first_byte |= 0b1000;
+ }
+
+ first_byte |= p.qos.into_u8() << 1;
+
+ if p.retain {
+ first_byte |= 0b0001;
+ }
+ buf.put_u8(first_byte);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::PubAck(p) => {
+ buf.put_u8(0b0100_0000);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::PubRec(p) => {
+ buf.put_u8(0b0101_0000);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::PubRel(p) => {
+ buf.put_u8(0b0110_0010);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::PubComp(p) => {
+ buf.put_u8(0b0111_0000);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::Subscribe(p) => {
+ buf.put_u8(0b1000_0010);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::SubAck(_) => {
+ unreachable!()
+ }
+ Packet::Unsubscribe(p) => {
+ buf.put_u8(0b1010_0010);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::UnsubAck(_) => {
+ buf.put_u8(0b1011_0000);
+ unreachable!()
+ }
+ Packet::PingReq => {
+ buf.put_u8(0b1100_0000);
+ buf.put_u8(0); // Variable header length.
+ }
+ Packet::PingResp => {
+ buf.put_u8(0b1101_0000);
+ buf.put_u8(0); // Variable header length.
+ }
+ Packet::Disconnect(p) => {
+ buf.put_u8(0b1110_0000);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ Packet::Auth(p) => {
+ buf.put_u8(0b1111_0000);
+ write_variable_integer(buf, p.wire_len())?;
+ p.write(buf)?;
+ }
+ }
+ Ok(())
+ }
+
+ pub fn read(header: FixedHeader, buf: Bytes) -> Result {
+ let packet = match header.packet_type {
+ PacketType::Connect => {
+ Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::ConnAck => {
+ Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::Publish => {
+ Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::PubAck => {
+ Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::PubRec => {
+ Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::PubRel => {
+ Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::PubComp => {
+ Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::Subscribe => {
+ Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::SubAck => {
+ Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(
+ header.flags,
+ header.remaining_length,
+ buf,
+ )?),
+ PacketType::UnsubAck => {
+ Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?)
+ }
+ PacketType::PingReq => Packet::PingReq,
+ PacketType::PingResp => Packet::PingResp,
+ PacketType::Disconnect => Packet::Disconnect(Disconnect::read(
+ header.flags,
+ header.remaining_length,
+ buf,
+ )?),
+ PacketType::Auth => {
+ Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?)
+ }
+ };
+ Ok(packet)
+ }
+
+ pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> {
+ let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?;
+
+ if header.remaining_length + header_length > buffer.len() {
+ return Err(ReadBytes::InsufficientBytes(
+ header.remaining_length - buffer.len(),
+ ));
+ }
+ buffer.advance(header_length);
+
+ let buf = buffer.split_to(header.remaining_length);
+
+ Ok(Packet::read(header, buf.into())?)
+ }
+}
+
+impl Display for Packet {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self{
+ Packet::Connect(c) => write!(f, "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", c.protocol_version, c.clean_session, c.username, c.password, c.keep_alive, c.client_id),
+ Packet::ConnAck(c) => write!(f, "ConnAck(session:{:?}, reason code{:?})", c.connack_flags, c.reason_code),
+ Packet::Publish(p) => write!(f, "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", &p.topic, p.qos, p.dup, p.retain, p.packet_identifier),
+ Packet::PubAck(p) => write!(f, "PubAck(id:{:?}, reason code: {:?})", p.packet_identifier, p.reason_code),
+ Packet::PubRec(p) => write!(f, "PubRec(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
+ Packet::PubRel(p) => write!(f, "PubRel(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
+ Packet::PubComp(p) => write!(f, "PubComp(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
+ Packet::Subscribe(_) => write!(f, "Subscribe()"),
+ Packet::SubAck(_) => write!(f, "SubAck()"),
+ Packet::Unsubscribe(_) => write!(f, "Unsubscribe()"),
+ Packet::UnsubAck(_) => write!(f, "UnsubAck()"),
+ Packet::PingReq => write!(f, "PingReq"),
+ Packet::PingResp => write!(f, "PingResp"),
+ Packet::Disconnect(d) => write!(f, "Disconnect(reason code: {:?})", d.reason_code),
+ Packet::Auth(_) => write!(f, "Auth()"),
+ }
+ }
+}
+
+// 2.1.1 Fixed Header
+// ```
+// 7 3 0
+// +--------------------------+--------------------------+
+// byte 1 | MQTT Control Packet Type | Flags for Packet type |
+// +--------------------------+--------------------------+
+// | Remaining Length |
+// +-----------------------------------------------------+
+//
+// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021
+// ```
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
+pub struct FixedHeader {
+ pub packet_type: PacketType,
+ pub flags: u8,
+ pub remaining_length: usize,
+}
+
+impl FixedHeader {
+ pub fn read_fixed_header(
+ mut header: Iter,
+ ) -> Result<(Self, usize), ReadBytes> {
+ if header.len() < 2 {
+ return Err(ReadBytes::InsufficientBytes(2 - header.len()));
+ }
+
+ let first_byte = header.next().unwrap();
+
+ let (packet_type, flags) =
+ PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?;
+
+ let (remaining_length, length) = read_fixed_header_rem_len(header)?;
+
+ Ok((
+ Self {
+ packet_type,
+ flags,
+ remaining_length,
+ },
+ 1 + length,
+ ))
+ }
+}
+
+/// 2.1.2 MQTT Control Packet type
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
+pub enum PacketType {
+ Connect,
+ ConnAck,
+ Publish,
+ PubAck,
+ PubRec,
+ PubRel,
+ PubComp,
+ Subscribe,
+ SubAck,
+ Unsubscribe,
+ UnsubAck,
+ PingReq,
+ PingResp,
+ Disconnect,
+ Auth,
+}
+impl PacketType {
+ const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> {
+ match (value >> 4, value & 0x0f) {
+ (0b0001, 0) => Ok((PacketType::Connect, 0)),
+ (0b0010, 0) => Ok((PacketType::ConnAck, 0)),
+ (0b0011, flags) => Ok((PacketType::Publish, flags)),
+ (0b0100, 0) => Ok((PacketType::PubAck, 0)),
+ (0b0101, 0) => Ok((PacketType::PubRec, 0)),
+ (0b0110, 0b0010) => Ok((PacketType::PubRel, 0)),
+ (0b0111, 0) => Ok((PacketType::PubComp, 0)),
+ (0b1000, 0b0010) => Ok((PacketType::Subscribe, 0)),
+ (0b1001, 0) => Ok((PacketType::SubAck, 0)),
+ (0b1010, 0b0010) => Ok((PacketType::Unsubscribe, 0)),
+ (0b1011, 0) => Ok((PacketType::UnsubAck, 0)),
+ (0b1100, 0) => Ok((PacketType::PingReq, 0)),
+ (0b1101, 0) => Ok((PacketType::PingResp, 0)),
+ (0b1110, 0) => Ok((PacketType::Disconnect, 0)),
+ (0b1111, 0) => Ok((PacketType::Auth, 0)),
+ (_, _) => Err(DeserializeError::UnknownFixedHeader(value)),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use bytes::{Bytes, BytesMut};
+
+ use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties};
+ use crate::packets::disconnect::{Disconnect, DisconnectProperties};
+ use crate::packets::QoS;
+
+ use crate::packets::publish::{Publish, PublishProperties};
+ use crate::packets::pubrel::{PubRel, PubRelProperties};
+ use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode};
+ use crate::packets::Packet;
+
+ #[test]
+ fn test_connack_read() {
+ let connack = [
+ 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01,
+ 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01,
+ ];
+ let mut buf = BytesMut::new();
+ buf.extend(connack);
+
+ let res = Packet::read_from_buffer(&mut buf);
+ assert!(res.is_ok());
+ let res = res.unwrap();
+
+ let expected = ConnAck {
+ connack_flags: ConnAckFlags::SESSION_PRESENT,
+ reason_code: ConnAckReasonCode::Success,
+ connack_properties: ConnAckProperties {
+ session_expiry_interval: None,
+ receive_maximum: None,
+ maximum_qos: None,
+ retain_available: Some(true),
+ maximum_packet_size: Some(1048576),
+ assigned_client_id: None,
+ topic_alias_maximum: Some(65535),
+ reason_string: None,
+ user_properties: vec![],
+ wildcards_available: Some(true),
+ subscription_ids_available: Some(true),
+ shared_subscription_available: Some(true),
+ server_keep_alive: None,
+ response_info: None,
+ server_reference: None,
+ authentication_method: None,
+ authentication_data: None,
+ },
+ };
+
+ assert_eq!(Packet::ConnAck(expected), res);
+ }
+
+ #[test]
+ fn test_disconnect_read() {
+ let packet = [0xe0, 0x02, 0x8e, 0x00];
+ let mut buf = BytesMut::new();
+ buf.extend(packet);
+
+ let res = Packet::read_from_buffer(&mut buf);
+ assert!(res.is_ok());
+ let res = res.unwrap();
+
+ let expected = Disconnect {
+ reason_code: DisconnectReasonCode::SessionTakenOver,
+ properties: DisconnectProperties {
+ session_expiry_interval: None,
+ reason_string: None,
+ user_properties: vec![],
+ server_reference: None,
+ },
+ };
+
+ assert_eq!(Packet::Disconnect(expected), res);
+ }
+
+ #[test]
+ fn test_pingreq_read_write() {
+ let packet = [0xc0, 0x00];
+ let mut buf = BytesMut::new();
+ buf.extend(packet);
+
+ let res = Packet::read_from_buffer(&mut buf);
+ assert!(res.is_ok());
+ let res = res.unwrap();
+
+ assert_eq!(Packet::PingReq, res);
+
+ buf.clear();
+ Packet::PingReq.write(&mut buf).unwrap();
+ assert_eq!(buf.to_vec(), packet);
+ }
+
+ #[test]
+ fn test_pingresp_read_write() {
+ let packet = [0xd0, 0x00];
+ let mut buf = BytesMut::new();
+ buf.extend(packet);
+
+ let res = Packet::read_from_buffer(&mut buf);
+ assert!(res.is_ok());
+ let res = res.unwrap();
+
+ assert_eq!(Packet::PingResp, res);
+
+ buf.clear();
+ Packet::PingResp.write(&mut buf).unwrap();
+ assert_eq!(buf.to_vec(), packet);
+ }
+
+ #[test]
+ fn test_publish_read() {
+ let packet = [
+ 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74,
+ 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01,
+ 0x01, 0x09, 0x00, 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01,
+ ];
+
+ let mut buf = BytesMut::new();
+ buf.extend(packet);
+
+ let res = Packet::read_from_buffer(&mut buf);
+ assert!(res.is_ok());
+ let res = res.unwrap();
+
+ let expected = Publish {
+ dup: false,
+ qos: QoS::ExactlyOnce,
+ retain: true,
+ topic: "test/123/test/blabla".to_string(),
+ packet_identifier: Some(13779),
+ publish_properties: PublishProperties {
+ payload_format_indicator: Some(1),
+ message_expiry_interval: None,
+ topic_alias: None,
+ response_topic: None,
+ correlation_data: Some(Bytes::from_static(b"1212")),
+ subscription_identifier: vec![1],
+ user_properties: vec![],
+ content_type: None,
+ },
+ payload: Bytes::from_static(b""),
+ };
+
+ assert_eq!(Packet::Publish(expected), res);
+ }
+
+ #[test]
+ fn test_pubrel_read_write() {
+ let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00];
+
+ let mut buffer = BytesMut::from_iter(bytes);
+
+ let res = Packet::read_from_buffer(&mut buffer);
+
+ assert!(res.is_ok());
+
+ let packet = res.unwrap();
+
+ let expected = PubRel {
+ packet_identifier: 13779,
+ reason_code: PubRelReasonCode::Success,
+ properties: PubRelProperties {
+ reason_string: None,
+ user_properties: vec![],
+ },
+ };
+
+ assert_eq!(packet, Packet::PubRel(expected));
+
+ buffer.clear();
+
+ packet.write(&mut buffer).unwrap();
+
+ // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format.
+ assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec())
+ }
+
+ #[test]
+ fn test_pubrel_read_smallest_format() {
+ let bytes = [0x62, 0x02, 0x35, 0xd3];
+
+ let mut buffer = BytesMut::from_iter(bytes);
+
+ let res = Packet::read_from_buffer(&mut buffer);
+
+ assert!(res.is_ok());
+
+ let packet = res.unwrap();
+
+ let expected = PubRel {
+ packet_identifier: 13779,
+ reason_code: PubRelReasonCode::Success,
+ properties: PubRelProperties {
+ reason_string: None,
+ user_properties: vec![],
+ },
+ };
+
+ assert_eq!(packet, Packet::PubRel(expected));
+
+ buffer.clear();
+
+ packet.write(&mut buffer).unwrap();
+
+ assert_eq!(buffer.to_vec(), bytes.to_vec())
+ }
+}
diff --git a/src/packets/packets.rs b/src/packets/packets.rs
deleted file mode 100644
index 8831eaa..0000000
--- a/src/packets/packets.rs
+++ /dev/null
@@ -1,531 +0,0 @@
-use core::slice::Iter;
-use std::fmt::Display;
-
-use bytes::{Buf, BufMut, Bytes, BytesMut};
-
-use super::error::{DeserializeError, ReadBytes, SerializeError};
-use super::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite, WireLength};
-use super::{read_fixed_header_rem_len, write_variable_integer};
-
-use super::auth::Auth;
-use super::connack::ConnAck;
-use super::connect::Connect;
-use super::disconnect::Disconnect;
-use super::puback::PubAck;
-use super::pubcomp::PubComp;
-use super::publish::Publish;
-use super::pubrec::PubRec;
-use super::pubrel::PubRel;
-use super::suback::SubAck;
-use super::subscribe::Subscribe;
-use super::unsuback::UnsubAck;
-use super::unsubscribe::Unsubscribe;
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum Packet {
- Connect(Connect),
- ConnAck(ConnAck),
- Publish(Publish),
- PubAck(PubAck),
- PubRec(PubRec),
- PubRel(PubRel),
- PubComp(PubComp),
- Subscribe(Subscribe),
- SubAck(SubAck),
- Unsubscribe(Unsubscribe),
- UnsubAck(UnsubAck),
- PingReq,
- PingResp,
- Disconnect(Disconnect),
- Auth(Auth),
-}
-
-impl Packet {
- pub fn packet_type(&self) -> PacketType {
- match self {
- Packet::Connect(_) => PacketType::Connect,
- Packet::ConnAck(_) => PacketType::ConnAck,
- Packet::Publish(_) => PacketType::Publish,
- Packet::PubAck(_) => PacketType::PubAck,
- Packet::PubRec(_) => PacketType::PubRec,
- Packet::PubRel(_) => PacketType::PubRel,
- Packet::PubComp(_) => PacketType::PubComp,
- Packet::Subscribe(_) => PacketType::Subscribe,
- Packet::SubAck(_) => PacketType::SubAck,
- Packet::Unsubscribe(_) => PacketType::Unsubscribe,
- Packet::UnsubAck(_) => PacketType::UnsubAck,
- Packet::PingReq => PacketType::PingReq,
- Packet::PingResp => PacketType::PingResp,
- Packet::Disconnect(_) => PacketType::Disconnect,
- Packet::Auth(_) => PacketType::Auth,
- }
- }
-
- pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
- match self {
- Packet::Connect(p) => {
- buf.put_u8(0b0001_0000);
- write_variable_integer(buf, p.wire_len())?;
-
- p.write(buf)?;
- }
- Packet::ConnAck(_) => {
- unreachable!()
- }
- Packet::Publish(p) => {
- let mut first_byte = 0b0011_0000u8;
- if p.dup {
- first_byte |= 0b1000;
- }
-
- first_byte |= p.qos.into_u8() << 1;
-
- if p.retain {
- first_byte |= 0b0001;
- }
- buf.put_u8(first_byte);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::PubAck(p) => {
- buf.put_u8(0b0100_0000);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::PubRec(p) => {
- buf.put_u8(0b0101_0000);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::PubRel(p) => {
- buf.put_u8(0b0110_0010);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::PubComp(p) => {
- buf.put_u8(0b0111_0000);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::Subscribe(p) => {
- buf.put_u8(0b1000_0010);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::SubAck(_) => {
- unreachable!()
- }
- Packet::Unsubscribe(p) => {
- buf.put_u8(0b1010_0010);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::UnsubAck(_) => {
- buf.put_u8(0b1011_0000);
- unreachable!()
- }
- Packet::PingReq => {
- buf.put_u8(0b1100_0000);
- buf.put_u8(0); // Variable header length.
- }
- Packet::PingResp => {
- buf.put_u8(0b1101_0000);
- buf.put_u8(0); // Variable header length.
- }
- Packet::Disconnect(p) => {
- buf.put_u8(0b1110_0000);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- Packet::Auth(p) => {
- buf.put_u8(0b1111_0000);
- write_variable_integer(buf, p.wire_len())?;
- p.write(buf)?;
- }
- }
- Ok(())
- }
-
- pub fn read(header: FixedHeader, buf: Bytes) -> Result {
- let packet = match header.packet_type {
- PacketType::Connect => {
- Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::ConnAck => {
- Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::Publish => {
- Packet::Publish(Publish::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::PubAck => {
- Packet::PubAck(PubAck::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::PubRec => {
- Packet::PubRec(PubRec::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::PubRel => {
- Packet::PubRel(PubRel::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::PubComp => {
- Packet::PubComp(PubComp::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::Subscribe => {
- Packet::Subscribe(Subscribe::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::SubAck => {
- Packet::SubAck(SubAck::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::read(
- header.flags,
- header.remaining_length,
- buf,
- )?),
- PacketType::UnsubAck => {
- Packet::UnsubAck(UnsubAck::read(header.flags, header.remaining_length, buf)?)
- }
- PacketType::PingReq => Packet::PingReq,
- PacketType::PingResp => Packet::PingResp,
- PacketType::Disconnect => Packet::Disconnect(Disconnect::read(
- header.flags,
- header.remaining_length,
- buf,
- )?),
- PacketType::Auth => {
- Packet::Auth(Auth::read(header.flags, header.remaining_length, buf)?)
- }
- };
- Ok(packet)
- }
-
- pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> {
- let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?;
-
- if header.remaining_length + header_length > buffer.len() {
- return Err(ReadBytes::InsufficientBytes(
- header.remaining_length - buffer.len(),
- ));
- }
- buffer.advance(header_length);
-
- let buf = buffer.split_to(header.remaining_length);
-
- return Ok(Packet::read(header, buf.into())?);
- }
-}
-
-impl Display for Packet {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self{
- Packet::Connect(c) => write!(f, "Connect(version: {:?}, clean: {}, username: {:?}, password: {:?}, keep_alive: {}, client_id: {})", c.protocol_version, c.clean_session, c.username, c.password, c.keep_alive, c.client_id),
- Packet::ConnAck(c) => write!(f, "ConnAck(session:{:?}, reason code{:?})", c.connack_flags, c.reason_code),
- Packet::Publish(p) => write!(f, "Publish(topic: {}, qos: {:?}, dup: {:?}, retain: {:?}, packet id: {:?})", &p.topic, p.qos, p.dup, p.retain, p.packet_identifier),
- Packet::PubAck(p) => write!(f, "PubAck(id:{:?}, reason code: {:?})", p.packet_identifier, p.reason_code),
- Packet::PubRec(p) => write!(f, "PubRec(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
- Packet::PubRel(p) => write!(f, "PubRel(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
- Packet::PubComp(p) => write!(f, "PubComp(id: {}, reason code: {:?})", p.packet_identifier, p.reason_code),
- Packet::Subscribe(_) => write!(f, "Subscribe()"),
- Packet::SubAck(_) => write!(f, "SubAck()"),
- Packet::Unsubscribe(_) => write!(f, "Unsubscribe()"),
- Packet::UnsubAck(_) => write!(f, "UnsubAck()"),
- Packet::PingReq => write!(f, "PingReq"),
- Packet::PingResp => write!(f, "PingResp"),
- Packet::Disconnect(d) => write!(f, "Disconnect(reason code: {:?})", d.reason_code),
- Packet::Auth(_) => write!(f, "Auth()"),
- }
- }
-}
-
-// 2.1.1 Fixed Header
-// ```
-// 7 3 0
-// +--------------------------+--------------------------+
-// byte 1 | MQTT Control Packet Type | Flags for Packet type |
-// +--------------------------+--------------------------+
-// | Remaining Length |
-// +-----------------------------------------------------+
-//
-// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021
-// ```
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
-pub struct FixedHeader {
- pub packet_type: PacketType,
- pub flags: u8,
- pub remaining_length: usize,
-}
-
-impl FixedHeader {
- pub fn read_fixed_header(
- mut header: Iter,
- ) -> Result<(Self, usize), ReadBytes> {
- if header.len() < 2 {
- return Err(ReadBytes::InsufficientBytes(2 - header.len()));
- }
-
- let first_byte = header.next().unwrap();
-
- let (packet_type, flags) =
- PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?;
-
- let (remaining_length, length) = read_fixed_header_rem_len(header)?;
-
- Ok((
- Self {
- packet_type,
- flags,
- remaining_length,
- },
- 1 + length,
- ))
- }
-}
-
-/// 2.1.2 MQTT Control Packet type
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
-pub enum PacketType {
- Connect,
- ConnAck,
- Publish,
- PubAck,
- PubRec,
- PubRel,
- PubComp,
- Subscribe,
- SubAck,
- Unsubscribe,
- UnsubAck,
- PingReq,
- PingResp,
- Disconnect,
- Auth,
-}
-impl PacketType {
- const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> {
- match (value >> 4, value & 0x0f) {
- (0b0001, 0) => Ok((PacketType::Connect, 0)),
- (0b0010, 0) => Ok((PacketType::ConnAck, 0)),
- (0b0011, flags) => Ok((PacketType::Publish, flags)),
- (0b0100, 0) => Ok((PacketType::PubAck, 0)),
- (0b0101, 0) => Ok((PacketType::PubRec, 0)),
- (0b0110, 0b0010) => Ok((PacketType::PubRel, 0)),
- (0b0111, 0) => Ok((PacketType::PubComp, 0)),
- (0b1000, 0b0010) => Ok((PacketType::Subscribe, 0)),
- (0b1001, 0) => Ok((PacketType::SubAck, 0)),
- (0b1010, 0b0010) => Ok((PacketType::Unsubscribe, 0)),
- (0b1011, 0) => Ok((PacketType::UnsubAck, 0)),
- (0b1100, 0) => Ok((PacketType::PingReq, 0)),
- (0b1101, 0) => Ok((PacketType::PingResp, 0)),
- (0b1110, 0) => Ok((PacketType::Disconnect, 0)),
- (0b1111, 0) => Ok((PacketType::Auth, 0)),
- (_, _) => Err(DeserializeError::UnknownFixedHeader(value)),
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use bytes::{Bytes, BytesMut};
-
- use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties};
- use crate::packets::disconnect::{Disconnect, DisconnectProperties};
- use crate::packets::QoS;
-
- use crate::packets::packets::Packet;
- use crate::packets::publish::{Publish, PublishProperties};
- use crate::packets::pubrel::{PubRel, PubRelProperties};
- use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode};
-
- #[test]
- fn test_connack_read() {
- let connack = [
- 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01,
- 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01,
- ];
- let mut buf = BytesMut::new();
- buf.extend(connack);
-
- let res = Packet::read_from_buffer(&mut buf);
- assert!(res.is_ok());
- let res = res.unwrap();
-
- let expected = ConnAck {
- connack_flags: ConnAckFlags::SESSION_PRESENT,
- reason_code: ConnAckReasonCode::Success,
- connack_properties: ConnAckProperties {
- session_expiry_interval: None,
- receive_maximum: None,
- maximum_qos: None,
- retain_available: Some(true),
- maximum_packet_size: Some(1048576),
- assigned_client_id: None,
- topic_alias_maximum: Some(65535),
- reason_string: None,
- user_properties: vec![],
- wildcards_available: Some(true),
- subscription_ids_available: Some(true),
- shared_subscription_available: Some(true),
- server_keep_alive: None,
- response_info: None,
- server_reference: None,
- authentication_method: None,
- authentication_data: None,
- },
- };
-
- assert_eq!(Packet::ConnAck(expected), res);
- }
-
- #[test]
- fn test_disconnect_read() {
- let packet = [0xe0, 0x02, 0x8e, 0x00];
- let mut buf = BytesMut::new();
- buf.extend(packet);
-
- let res = Packet::read_from_buffer(&mut buf);
- assert!(res.is_ok());
- let res = res.unwrap();
-
- let expected = Disconnect {
- reason_code: DisconnectReasonCode::SessionTakenOver,
- properties: DisconnectProperties {
- session_expiry_interval: None,
- reason_string: None,
- user_properties: vec![],
- server_reference: None,
- },
- };
-
- assert_eq!(Packet::Disconnect(expected), res);
- }
-
- #[test]
- fn test_pingreq_read_write() {
- let packet = [0xc0, 0x00];
- let mut buf = BytesMut::new();
- buf.extend(packet);
-
- let res = Packet::read_from_buffer(&mut buf);
- assert!(res.is_ok());
- let res = res.unwrap();
-
- assert_eq!(Packet::PingReq, res);
-
- buf.clear();
- Packet::PingReq.write(&mut buf).unwrap();
- assert_eq!(buf.to_vec(), packet);
- }
-
- #[test]
- fn test_pingresp_read_write() {
- let packet = [0xd0, 0x00];
- let mut buf = BytesMut::new();
- buf.extend(packet);
-
- let res = Packet::read_from_buffer(&mut buf);
- assert!(res.is_ok());
- let res = res.unwrap();
-
- assert_eq!(Packet::PingResp, res);
-
- buf.clear();
- Packet::PingResp.write(&mut buf).unwrap();
- assert_eq!(buf.to_vec(), packet);
- }
-
- #[test]
- fn test_publish_read() {
- let packet = [
- 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74,
- 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01,
- 0x01, 0x09, 0x00, 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01,
- ];
-
- let mut buf = BytesMut::new();
- buf.extend(packet);
-
- let res = Packet::read_from_buffer(&mut buf);
- assert!(res.is_ok());
- let res = res.unwrap();
-
- let expected = Publish {
- dup: false,
- qos: QoS::ExactlyOnce,
- retain: true,
- topic: "test/123/test/blabla".to_string(),
- packet_identifier: Some(13779),
- publish_properties: PublishProperties {
- payload_format_indicator: Some(1),
- message_expiry_interval: None,
- topic_alias: None,
- response_topic: None,
- correlation_data: Some(Bytes::from_static(b"1212")),
- subscription_identifier: vec![1],
- user_properties: vec![],
- content_type: None,
- },
- payload: Bytes::from_static(b""),
- };
-
- assert_eq!(Packet::Publish(expected), res);
- }
-
- #[test]
- fn test_pubrel_read_write() {
- let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00];
-
- let mut buffer = BytesMut::from_iter(&bytes);
-
- let res = Packet::read_from_buffer(&mut buffer);
-
- assert!(res.is_ok());
-
- let packet = res.unwrap();
-
- let expected = PubRel {
- packet_identifier: 13779,
- reason_code: PubRelReasonCode::Success,
- properties: PubRelProperties {
- reason_string: None,
- user_properties: vec![],
- },
- };
-
- assert_eq!(packet, Packet::PubRel(expected));
-
- buffer.clear();
-
- packet.write(&mut buffer).unwrap();
-
- // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format.
- assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec())
- }
-
- #[test]
- fn test_pubrel_read_smallest_format() {
- let bytes = [0x62, 0x02, 0x35, 0xd3];
-
- let mut buffer = BytesMut::from_iter(&bytes);
-
- let res = Packet::read_from_buffer(&mut buffer);
-
- assert!(res.is_ok());
-
- let packet = res.unwrap();
-
- let expected = PubRel {
- packet_identifier: 13779,
- reason_code: PubRelReasonCode::Success,
- properties: PubRelProperties {
- reason_string: None,
- user_properties: vec![],
- },
- };
-
- assert_eq!(packet, Packet::PubRel(expected));
-
- buffer.clear();
-
- packet.write(&mut buffer).unwrap();
-
- assert_eq!(buffer.to_vec(), bytes.to_vec())
- }
-}
diff --git a/src/packets/puback.rs b/src/packets/puback.rs
index bf22d0d..493b3ca 100644
--- a/src/packets/puback.rs
+++ b/src/packets/puback.rs
@@ -56,13 +56,15 @@ impl VariableHeaderWrite for PubAck {
if self.reason_code == PubAckReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty() {
- ()
+ && self.properties.user_properties.is_empty()
+ {
+ // nothing here
}
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
self.reason_code.write(buf)?;
- }
+ }
else {
self.reason_code.write(buf)?;
self.properties.write(buf)?;
@@ -75,11 +77,13 @@ impl WireLength for PubAck {
fn wire_len(&self) -> usize {
if self.reason_code == PubAckReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
2
- }
+ }
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
3
}
else {
@@ -184,7 +188,6 @@ mod tests {
};
use bytes::{BufMut, Bytes, BytesMut};
-
#[test]
fn test_wire_len() {
let mut puback = PubAck {
@@ -196,10 +199,10 @@ mod tests {
let mut buf = BytesMut::new();
puback.write(&mut buf).unwrap();
-
+
assert_eq!(2, puback.wire_len());
assert_eq!(2, buf.len());
-
+
puback.reason_code = PubAckReasonCode::NotAuthorized;
buf.clear();
puback.write(&mut buf).unwrap();
@@ -208,7 +211,6 @@ mod tests {
assert_eq!(3, buf.len());
}
-
#[test]
fn test_read_simple_puback() {
let stream = &[
diff --git a/src/packets/pubcomp.rs b/src/packets/pubcomp.rs
index ec60ede..b1fdfdd 100644
--- a/src/packets/pubcomp.rs
+++ b/src/packets/pubcomp.rs
@@ -66,13 +66,15 @@ impl VariableHeaderWrite for PubComp {
if self.reason_code == PubCompReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty() {
- ()
+ && self.properties.user_properties.is_empty()
+ {
+ // nothing here
}
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
self.reason_code.write(buf)?;
- }
+ }
else {
self.reason_code.write(buf)?;
self.properties.write(buf)?;
@@ -85,11 +87,13 @@ impl WireLength for PubComp {
fn wire_len(&self) -> usize {
if self.reason_code == PubCompReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
2
- }
+ }
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
3
}
else {
@@ -98,7 +102,7 @@ impl WireLength for PubComp {
}
}
-#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)]
pub struct PubCompProperties {
pub reason_string: Option,
pub user_properties: Vec<(String, String)>,
@@ -110,15 +114,6 @@ impl PubCompProperties {
}
}
-impl Default for PubCompProperties {
- fn default() -> Self {
- Self {
- reason_string: Default::default(),
- user_properties: Default::default(),
- }
- }
-}
-
impl MqttRead for PubCompProperties {
fn read(buf: &mut bytes::Bytes) -> Result {
let (len, _) = read_variable_integer(buf)?;
@@ -203,7 +198,6 @@ mod tests {
};
use bytes::{BufMut, Bytes, BytesMut};
-
#[test]
fn test_wire_len() {
let mut pubcomp = PubComp {
diff --git a/src/packets/publish.rs b/src/packets/publish.rs
index cbe7227..5262075 100644
--- a/src/packets/publish.rs
+++ b/src/packets/publish.rs
@@ -101,7 +101,8 @@ impl WireLength for Publish {
let len = self.topic.wire_len()
+ if self.packet_identifier.is_some() {
2
- } else {
+ }
+ else {
0
}
+ self.publish_properties.wire_len()
@@ -152,7 +153,8 @@ impl MqttRead for PublishProperties {
if len == 0 {
return Ok(Self::default());
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"PublishProperties".to_string(),
buf.len(),
diff --git a/src/packets/pubrec.rs b/src/packets/pubrec.rs
index 448ee1a..737463c 100644
--- a/src/packets/pubrec.rs
+++ b/src/packets/pubrec.rs
@@ -65,13 +65,15 @@ impl VariableHeaderWrite for PubRec {
if self.reason_code == PubRecReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty() {
- ()
+ && self.properties.user_properties.is_empty()
+ {
+ // nothing here
}
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
self.reason_code.write(buf)?;
- }
+ }
else {
self.reason_code.write(buf)?;
self.properties.write(buf)?;
@@ -84,11 +86,13 @@ impl WireLength for PubRec {
fn wire_len(&self) -> usize {
if self.reason_code == PubRecReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
2
- }
+ }
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
3
}
else {
diff --git a/src/packets/pubrel.rs b/src/packets/pubrel.rs
index 985ebbf..3c08f3b 100644
--- a/src/packets/pubrel.rs
+++ b/src/packets/pubrel.rs
@@ -38,13 +38,15 @@ impl VariableHeaderRead for PubRel {
reason_code: PubRelReasonCode::Success,
properties: PubRelProperties::default(),
})
- } else if remaining_length == 3 {
+ }
+ else if remaining_length == 3 {
Ok(Self {
packet_identifier: u16::read(&mut buf)?,
reason_code: PubRelReasonCode::read(&mut buf)?,
properties: PubRelProperties::default(),
})
- } else {
+ }
+ else {
Ok(Self {
packet_identifier: u16::read(&mut buf)?,
reason_code: PubRelReasonCode::read(&mut buf)?,
@@ -60,13 +62,15 @@ impl VariableHeaderWrite for PubRel {
if self.reason_code == PubRelReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty() {
- ()
+ && self.properties.user_properties.is_empty()
+ {
+ // Nothing here
}
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
self.reason_code.write(buf)?;
- }
+ }
else {
self.reason_code.write(buf)?;
self.properties.write(buf)?;
@@ -79,11 +83,13 @@ impl WireLength for PubRel {
fn wire_len(&self) -> usize {
if self.reason_code == PubRelReasonCode::Success
&& self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
2
- }
+ }
else if self.properties.reason_string.is_none()
- && self.properties.user_properties.is_empty(){
+ && self.properties.user_properties.is_empty()
+ {
3
}
else {
@@ -188,7 +194,6 @@ mod tests {
};
use bytes::{BufMut, Bytes, BytesMut};
-
#[test]
fn test_wire_len() {
let mut pubrel = PubRel {
@@ -227,18 +232,17 @@ mod tests {
assert_eq!(2, buf.len());
let pubrel = PubRel::read(0, 2, buf.into()).unwrap();
-
+
assert_eq!(expected_pubrel, pubrel);
-
+
let mut buf = BytesMut::new();
expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound;
expected_pubrel.write(&mut buf).unwrap();
-
+
assert_eq!(3, buf.len());
-
+
let pubrel = PubRel::read(0, 3, buf.into()).unwrap();
assert_eq!(expected_pubrel, pubrel);
-
}
#[test]
diff --git a/src/packets/suback.rs b/src/packets/suback.rs
index f86fa40..31dc5bb 100644
--- a/src/packets/suback.rs
+++ b/src/packets/suback.rs
@@ -77,7 +77,8 @@ impl MqttRead for SubAckProperties {
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"SubAckProperties".to_string(),
buf.len(),
@@ -94,7 +95,8 @@ impl MqttRead for SubAckProperties {
let (subscription_id, _) = read_variable_integer(&mut properties_data)?;
properties.subscription_id = Some(subscription_id);
- } else {
+ }
+ else {
return Err(DeserializeError::DuplicateProperty(
PropertyType::SubscriptionIdentifier,
));
diff --git a/src/packets/subscribe.rs b/src/packets/subscribe.rs
index 2a9c12a..94a1b5f 100644
--- a/src/packets/subscribe.rs
+++ b/src/packets/subscribe.rs
@@ -96,7 +96,8 @@ impl MqttRead for SubscribeProperties {
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"SubscribeProperties".to_string(),
buf.len(),
@@ -113,7 +114,8 @@ impl MqttRead for SubscribeProperties {
let (subscription_id, _) = read_variable_integer(&mut properties_data)?;
properties.subscription_id = Some(subscription_id);
- } else {
+ }
+ else {
return Err(DeserializeError::DuplicateProperty(
PropertyType::SubscriptionIdentifier,
));
@@ -267,7 +269,7 @@ mod tests {
use crate::packets::{
mqtt_traits::{MqttRead, VariableHeaderRead, VariableHeaderWrite},
- packets::Packet,
+ Packet,
};
use super::WireLength;
diff --git a/src/packets/unsuback.rs b/src/packets/unsuback.rs
index da1aa1b..94536bf 100644
--- a/src/packets/unsuback.rs
+++ b/src/packets/unsuback.rs
@@ -71,7 +71,8 @@ impl MqttRead for UnsubAckProperties {
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"UnsubAckProperties".to_string(),
buf.len(),
@@ -86,7 +87,8 @@ impl MqttRead for UnsubAckProperties {
PropertyType::ReasonString => {
if properties.reason_string.is_none() {
properties.reason_string = Some(String::read(&mut properties_data)?);
- } else {
+ }
+ else {
return Err(DeserializeError::DuplicateProperty(
PropertyType::SubscriptionIdentifier,
));
diff --git a/src/packets/unsubscribe.rs b/src/packets/unsubscribe.rs
index db81eda..225410c 100644
--- a/src/packets/unsubscribe.rs
+++ b/src/packets/unsubscribe.rs
@@ -76,7 +76,8 @@ impl MqttRead for UnsubscribeProperties {
if len == 0 {
return Ok(properties);
- } else if buf.len() < len {
+ }
+ else if buf.len() < len {
return Err(DeserializeError::InsufficientData(
"UnsubscribeProperties".to_string(),
buf.len(),
diff --git a/src/smol_network.rs b/src/smol_network.rs
new file mode 100644
index 0000000..eab20e9
--- /dev/null
+++ b/src/smol_network.rs
@@ -0,0 +1,166 @@
+use async_channel::{Receiver, Sender};
+
+use futures::{select, FutureExt};
+use smol::io::{AsyncReadExt, AsyncWriteExt};
+
+use std::time::{Duration, Instant};
+
+use crate::connect_options::ConnectOptions;
+use crate::connections::smol_stream::SmolStream;
+use crate::error::ConnectionError;
+use crate::packets::error::ReadBytes;
+use crate::packets::{Packet, PacketType};
+use crate::NetworkStatus;
+
+pub struct SmolNetwork {
+ network: Option>,
+
+ /// Options of the current mqtt connection
+ options: ConnectOptions,
+
+ last_network_action: Instant,
+ await_pingresp: Option,
+ perform_keep_alive: bool,
+
+ network_to_handler_s: Sender,
+
+ to_network_r: Receiver,
+}
+
+impl SmolNetwork
+where
+ S: AsyncReadExt + AsyncWriteExt + Sized + Unpin,
+{
+ pub fn new(
+ options: ConnectOptions,
+ network_to_handler_s: Sender,
+ to_network_r: Receiver,
+ ) -> Self {
+ Self {
+ network: None,
+
+ options,
+
+ last_network_action: Instant::now(),
+ await_pingresp: None,
+ perform_keep_alive: true,
+
+ network_to_handler_s,
+ to_network_r,
+ }
+ }
+
+ pub fn reset(&mut self) {
+ self.network = None;
+ }
+
+ pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> {
+ let (network, connack) = SmolStream::connect(&self.options, stream).await?;
+
+ self.network = Some(network);
+ self.network_to_handler_s.send(connack).await?;
+ self.last_network_action = Instant::now();
+ if self.options.keep_alive_interval_s == 0 {
+ self.perform_keep_alive = false;
+ }
+ Ok(())
+ }
+
+ pub async fn run(&mut self) -> Result {
+ if self.network.is_none() {
+ return Err(ConnectionError::NoNetwork);
+ }
+
+ match self.select().await {
+ Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active),
+ otherwise => {
+ self.reset();
+ otherwise
+ }
+ }
+ }
+
+ async fn select(&mut self) -> Result {
+ if self.network.is_none() {
+ return Err(ConnectionError::NoNetwork);
+ }
+
+ let SmolNetwork {
+ network,
+ options: _,
+ last_network_action,
+ await_pingresp,
+ perform_keep_alive,
+ network_to_handler_s,
+ to_network_r,
+ } = self;
+
+ let sleep;
+ if !(*perform_keep_alive) {
+ sleep = Duration::new(3600, 0);
+ }
+ else if let Some(instant) = await_pingresp {
+ sleep =
+ *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now();
+ }
+ else {
+ sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s)
+ - Instant::now();
+ }
+
+ if let Some(stream) = network {
+ loop {
+ select! {
+ _ = stream.read_bytes().fuse() => {
+ match stream.parse_messages(network_to_handler_s).await {
+ Err(ReadBytes::Err(err)) => return Err(err),
+ Err(ReadBytes::InsufficientBytes(_)) => continue,
+ Ok(Some(PacketType::PingResp)) => {
+ *await_pingresp = None;
+ return Ok(NetworkStatus::Active)
+ },
+ Ok(Some(PacketType::Disconnect)) => {
+ return Ok(NetworkStatus::IncomingDisconnect)
+ },
+ Ok(_) => {
+ return Ok(NetworkStatus::Active)
+ }
+ };
+ },
+ outgoing = to_network_r.recv().fuse() => {
+ let packet = outgoing?;
+ stream.write(&packet).await?;
+ *last_network_action = Instant::now();
+ if packet.packet_type() == PacketType::Disconnect{
+ return Ok(NetworkStatus::OutgoingDisconnect);
+ }
+ return Ok(NetworkStatus::Active);
+ },
+ _ = smol::Timer::after(sleep).fuse() => {
+ if await_pingresp.is_none() && *perform_keep_alive{
+ let packet = Packet::PingReq;
+ stream.write(&packet).await?;
+ *last_network_action = Instant::now();
+ *await_pingresp = Some(Instant::now());
+ return Ok(NetworkStatus::Active);
+ }
+ else{
+ return Ok(NetworkStatus::NoPingResp);
+ }
+ },
+ }
+ }
+ }
+ else {
+ Err(ConnectionError::NoNetwork)
+ }
+ }
+}
+
+#[test]
+fn test() {
+ let a = Instant::now() - Duration::from_secs(100);
+
+ let sleep = a + Duration::from_secs(60) - Instant::now();
+ dbg!(sleep);
+}
diff --git a/src/state.rs b/src/state.rs
index de0e9fc..0aaeeac 100644
--- a/src/state.rs
+++ b/src/state.rs
@@ -1,7 +1,6 @@
use std::collections::{BTreeMap, BTreeSet};
use async_channel::Receiver;
-use async_mutex::Mutex;
use crate::{
available_packet_ids::AvailablePacketIds,
diff --git a/src/tests/handler_tests.rs b/src/tests/handler_tests.rs
deleted file mode 100644
index 8b13789..0000000
--- a/src/tests/handler_tests.rs
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/tests/mod.rs b/src/tests/mod.rs
index 6e5737d..0525f44 100644
--- a/src/tests/mod.rs
+++ b/src/tests/mod.rs
@@ -1,4 +1,2 @@
mod connection_tests;
-mod handler_tests;
pub mod resources;
-pub mod stages;
diff --git a/src/tests/resources/test_packets.rs b/src/tests/resources/test_packets.rs
index dabf117..91e8ef3 100644
--- a/src/tests/resources/test_packets.rs
+++ b/src/tests/resources/test_packets.rs
@@ -1,14 +1,9 @@
-#[allow(dead_code)]
use bytes::Bytes;
use crate::packets::{
- Disconnect, DisconnectProperties,
- Packet,
- PubAck, PubAckProperties,
- Publish, PublishProperties,
- Subscribe, Subscription,
- QoS,
- reason_codes::{PubAckReasonCode, DisconnectReasonCode},
+ reason_codes::{DisconnectReasonCode, PubAckReasonCode},
+ Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties,
+ QoS, Subscribe, Subscription,
};
pub fn publish_packets() -> Vec {
diff --git a/src/tests/stages.rs b/src/tests/stages.rs
deleted file mode 100644
index 44f729f..0000000
--- a/src/tests/stages.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-use tracing::{trace, warn};
-
-use crate::{event_handler::AsyncEventHandler, packets::Packet};
-
-pub struct Nop { }
-impl AsyncEventHandler for Nop {
- fn handle<'a>(
- &'a mut self,
- event: &'a Packet,
- ) -> impl core::future::Future + Send + 'a {
- async move {
- // warn!("{:?}", event)
- }
- }
-}
-
-
-pub mod qos_2 {
- use crate::{
- client::AsyncClient,
- event_handler::AsyncEventHandler, packets::{Packet, PacketType},
- // packets::{Packet, PacketType},
- };
-
- pub struct TestPubQoS2 {
- stage: StagePubQoS2,
- client: AsyncClient,
- }
- pub enum StagePubQoS2 {
- ConnAck,
- PubRec,
- PubComp,
- Done,
- }
- impl TestPubQoS2 {
- #[allow(dead_code)]
- pub fn new(client: AsyncClient) -> Self {
- TestPubQoS2 {
- stage: StagePubQoS2::ConnAck,
- client,
- }
- }
- }
- impl AsyncEventHandler for TestPubQoS2 {
- fn handle<'a>(
- &'a mut self,
- event: &'a Packet,
- ) -> impl core::future::Future + Send + 'a {
- async move {
- match self.stage {
- StagePubQoS2::ConnAck => {
- assert_eq!(event.packet_type(), PacketType::ConnAck);
- self.stage = StagePubQoS2::PubRec;
- }
- StagePubQoS2::PubRec => {
- assert_eq!(event.packet_type(), PacketType::PubRec);
- self.stage = StagePubQoS2::PubComp;
- }
- StagePubQoS2::PubComp => {
- assert_eq!(event.packet_type(), PacketType::PubComp);
- self.stage = StagePubQoS2::Done;
- self.client.disconnect().await.unwrap();
- }
- StagePubQoS2::Done => (),
- }
- }
- }
- }
-}
diff --git a/src/tokio_network.rs b/src/tokio_network.rs
new file mode 100644
index 0000000..79585d8
--- /dev/null
+++ b/src/tokio_network.rs
@@ -0,0 +1,156 @@
+use async_channel::{Receiver, Sender};
+
+use tokio::io::{AsyncReadExt, AsyncWriteExt};
+
+use std::time::{Duration, Instant};
+
+use crate::connect_options::ConnectOptions;
+use crate::connections::tokio_stream::TokioStream;
+use crate::error::ConnectionError;
+use crate::packets::error::ReadBytes;
+use crate::packets::{Packet, PacketType};
+use crate::NetworkStatus;
+
+pub struct TokioNetwork {
+ network: Option>,
+
+ /// Options of the current mqtt connection
+ options: ConnectOptions,
+
+ last_network_action: Instant,
+ await_pingresp: Option,
+ perform_keep_alive: bool,
+
+ network_to_handler_s: Sender,
+
+ to_network_r: Receiver,
+}
+
+impl TokioNetwork
+where
+ S: AsyncReadExt + AsyncWriteExt + Sized + Unpin,
+{
+ pub fn new(
+ options: ConnectOptions,
+ network_to_handler_s: Sender,
+ to_network_r: Receiver,
+ ) -> Self {
+ Self {
+ network: None,
+
+ options,
+
+ last_network_action: Instant::now(),
+ await_pingresp: None,
+ perform_keep_alive: true,
+
+ network_to_handler_s,
+ to_network_r,
+ }
+ }
+
+ pub fn reset(&mut self) {
+ self.network = None;
+ }
+
+ pub async fn connect(&mut self, stream: S) -> Result<(), ConnectionError> {
+ let (network, connack) = TokioStream::connect(&self.options, stream).await?;
+
+ self.network = Some(network);
+ self.network_to_handler_s.send(connack).await?;
+ self.last_network_action = Instant::now();
+ if self.options.keep_alive_interval_s == 0 {
+ self.perform_keep_alive = false;
+ }
+ Ok(())
+ }
+
+ pub async fn run(&mut self) -> Result {
+ if self.network.is_none() {
+ return Err(ConnectionError::NoNetwork);
+ }
+
+ match self.select().await {
+ Ok(NetworkStatus::Active) => Ok(NetworkStatus::Active),
+ otherwise => {
+ self.reset();
+ otherwise
+ }
+ }
+ }
+
+ async fn select(&mut self) -> Result {
+ let TokioNetwork {
+ network,
+ options: _,
+ last_network_action,
+ await_pingresp,
+ perform_keep_alive,
+ network_to_handler_s,
+ to_network_r,
+ } = self;
+
+ let sleep;
+ if let Some(instant) = await_pingresp {
+ sleep =
+ *instant + Duration::from_secs(self.options.keep_alive_interval_s) - Instant::now();
+ }
+ else {
+ sleep = *last_network_action + Duration::from_secs(self.options.keep_alive_interval_s)
+ - Instant::now();
+ }
+
+ if let Some(stream) = network {
+ loop {
+ tokio::select! {
+ _ = stream.read_bytes() => {
+ match stream.parse_messages(network_to_handler_s).await {
+ Err(ReadBytes::Err(err)) => return Err(err),
+ Err(ReadBytes::InsufficientBytes(_)) => continue,
+ Ok(Some(PacketType::PingResp)) => {
+ *await_pingresp = None;
+ return Ok(NetworkStatus::Active)
+ },
+ Ok(Some(PacketType::Disconnect)) => {
+ return Ok(NetworkStatus::IncomingDisconnect)
+ },
+ Ok(_) => {
+ return Ok(NetworkStatus::Active)
+ }
+ };
+ },
+ outgoing = to_network_r.recv() => {
+ let packet = outgoing?;
+ stream.write(&packet).await?;
+ *last_network_action = Instant::now();
+ if packet.packet_type() == PacketType::Disconnect{
+ return Ok(NetworkStatus::OutgoingDisconnect);
+ }
+ return Ok(NetworkStatus::Active);
+ },
+ _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => {
+ let packet = Packet::PingReq;
+ stream.write(&packet).await?;
+ *last_network_action = Instant::now();
+ *await_pingresp = Some(Instant::now());
+ return Ok(NetworkStatus::Active);
+ },
+ _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => {
+ return Ok(NetworkStatus::NoPingResp);
+ }
+ }
+ }
+ }
+ else {
+ Err(ConnectionError::NoNetwork)
+ }
+ }
+}
+
+#[test]
+fn test() {
+ let a = Instant::now() - Duration::from_secs(100);
+
+ let sleep = a + Duration::from_secs(60) - Instant::now();
+ dbg!(sleep);
+}
diff --git a/src/util/mod.rs b/src/util/mod.rs
index e5399c5..52ab346 100644
--- a/src/util/mod.rs
+++ b/src/util/mod.rs
@@ -1,2 +1,3 @@
pub mod constants;
pub mod timeout;
+pub mod tls;
diff --git a/src/util/timeout.rs b/src/util/timeout.rs
index 02bd5b4..69f175a 100644
--- a/src/util/timeout.rs
+++ b/src/util/timeout.rs
@@ -1,29 +1,29 @@
-use std::fmt::Display;
+// use std::fmt::Display;
-use futures_concurrency::future::Race;
+// use futures_concurrency::future::Race;
-#[derive(Debug, Clone, Copy)]
-pub struct Timeout(());
+// #[derive(Debug, Clone, Copy)]
+// pub struct Timeout(());
-impl Display for Timeout {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- write!(f, "Timeout")
- }
-}
+// impl Display for Timeout {
+// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+// write!(f, "Timeout")
+// }
+// }
-impl std::error::Error for Timeout {}
+// impl std::error::Error for Timeout {}
-pub async fn timeout(
- fut: impl core::future::Future,
- delay_seconds: u64,
-) -> Result {
- (async { Ok(fut.await) }, async {
- #[cfg(feature = "smol")]
- smol::Timer::after(std::time::Duration::from_secs(delay_seconds)).await;
- #[cfg(feature = "tokio")]
- tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await;
- Err(Timeout(()))
- })
- .race()
- .await
-}
+// pub async fn timeout(
+// fut: impl core::future::Future,
+// delay_seconds: u64,
+// ) -> Result {
+// (async { Ok(fut.await) }, async {
+// #[cfg(feature = "smol")]
+// smol::Timer::after(std::time::Duration::from_secs(delay_seconds)).await;
+// #[cfg(feature = "tokio")]
+// tokio::time::sleep(tokio::time::Duration::from_secs(delay_seconds)).await;
+// Err(Timeout(()))
+// })
+// .race()
+// .await
+// }
diff --git a/src/util/tls.rs b/src/util/tls.rs
new file mode 100644
index 0000000..39c4906
--- /dev/null
+++ b/src/util/tls.rs
@@ -0,0 +1,75 @@
+#[cfg(test)]
+pub mod tests {
+ use std::{
+ io::{BufReader, Cursor},
+ sync::Arc,
+ };
+
+ use rustls::{Certificate, ClientConfig, Error, OwnedTrustAnchor, RootCertStore};
+
+ #[derive(Debug, Clone)]
+ pub enum PrivateKey {
+ RSA(Vec),
+ ECC(Vec),
+ }
+
+ pub fn simple_rust_tls(
+ ca: Vec,
+ alpn: Option>>,
+ client_auth: Option<(Vec, PrivateKey)>,
+ ) -> Result, Error> {
+ let mut root_cert_store = RootCertStore::empty();
+
+ let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap();
+
+ let trust_anchors = ca_certs.iter().map_while(|cert| {
+ if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) {
+ Some(OwnedTrustAnchor::from_subject_spki_name_constraints(
+ ta.subject,
+ ta.spki,
+ ta.name_constraints,
+ ))
+ }
+ else {
+ None
+ }
+ });
+ root_cert_store.add_server_trust_anchors(trust_anchors);
+
+ assert!(!root_cert_store.is_empty());
+
+ let config = ClientConfig::builder()
+ .with_safe_defaults()
+ .with_root_certificates(root_cert_store);
+
+ let mut config = match client_auth {
+ Some((client_cert_info, client_private_info)) => {
+ let read_private_keys = match client_private_info {
+ PrivateKey::RSA(rsa) => {
+ rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa)))
+ }
+ PrivateKey::ECC(ecc) => {
+ rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc)))
+ }
+ }
+ .unwrap();
+
+ let key = read_private_keys.into_iter().next().unwrap();
+
+ let client_certs =
+ rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info)))
+ .unwrap();
+ let client_cert_chain = client_certs.into_iter().map(Certificate).collect();
+
+ config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))?
+ }
+ None => config.with_no_client_auth(),
+ };
+
+ if let Some(alpn) = alpn {
+ config.alpn_protocols.extend(alpn)
+ }
+
+ Ok(Arc::new(config))
+ }
+}