From eca25c006c069b93cae89fb2adee5a67974cb3c2 Mon Sep 17 00:00:00 2001 From: gunnarpieter <13799935+gunnarpieter@users.noreply.github.com> Date: Fri, 12 Jul 2024 23:38:53 +0200 Subject: [PATCH 01/45] Not sure if this is the way to go Especially if the reader is not buffered this will cause a lot of sys read calls. If the IO is buffered your ok. The current reading into a buffer and then attempting to deserialize from it is a bit tideious. --- mqrstt/src/packets/connack.rs | 6 +- mqrstt/src/packets/connect.rs | 6 +- mqrstt/src/packets/disconnect.rs | 4 +- mqrstt/src/packets/error.rs | 10 +- mqrstt/src/packets/macros.rs | 28 ++++ mqrstt/src/packets/mod.rs | 256 +++++++++++++++++------------ mqrstt/src/packets/mqtt_traits.rs | 6 +- mqrstt/src/packets/puback.rs | 4 +- mqrstt/src/packets/pubcomp.rs | 4 +- mqrstt/src/packets/publish.rs | 20 +-- mqrstt/src/packets/pubrec.rs | 4 +- mqrstt/src/packets/pubrel.rs | 2 +- mqrstt/src/packets/reason_codes.rs | 18 +- mqrstt/src/packets/suback.rs | 2 +- mqrstt/src/packets/subscribe.rs | 4 +- mqrstt/src/packets/unsuback.rs | 2 +- mqrstt/src/packets/unsubscribe.rs | 2 +- 17 files changed, 232 insertions(+), 146 deletions(-) create mode 100644 mqrstt/src/packets/macros.rs diff --git a/mqrstt/src/packets/connack.rs b/mqrstt/src/packets/connack.rs index fb198ab..bf7c7e9 100644 --- a/mqrstt/src/packets/connack.rs +++ b/mqrstt/src/packets/connack.rs @@ -23,7 +23,7 @@ pub struct ConnAck { impl VariableHeaderRead for ConnAck { fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { if header_len > buf.len() { - return Err(DeserializeError::InsufficientData("ConnAck".to_string(), buf.len(), header_len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), header_len)); } let connack_flags = ConnAckFlags::read(&mut buf)?; @@ -137,7 +137,7 @@ impl MqttRead for ConnAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("ConnAckProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -426,7 +426,7 @@ pub struct ConnAckFlags { impl MqttRead for ConnAckFlags { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConnAckFlags".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } let byte = buf.get_u8(); diff --git a/mqrstt/src/packets/connect.rs b/mqrstt/src/packets/connect.rs index 2ada053..f1cf0ed 100644 --- a/mqrstt/src/packets/connect.rs +++ b/mqrstt/src/packets/connect.rs @@ -253,7 +253,7 @@ impl Default for ConnectFlags { impl MqttRead for ConnectFlags { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConnectFlags".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } let byte = buf.get_u8(); @@ -373,7 +373,7 @@ impl MqttRead for ConnectProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("ConnectProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -566,7 +566,7 @@ impl MqttRead for LastWillProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("LastWillProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut property_data = buf.split_to(len); diff --git a/mqrstt/src/packets/disconnect.rs b/mqrstt/src/packets/disconnect.rs index d2307cb..19fd103 100644 --- a/mqrstt/src/packets/disconnect.rs +++ b/mqrstt/src/packets/disconnect.rs @@ -66,13 +66,13 @@ impl MqttRead for DisconnectProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("DisconnectProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut property_data = buf.split_to(len); loop { - match PropertyType::from_u8(u8::read(&mut property_data)?)? { + match PropertyType::try_from(u8::read(&mut property_data)?)? { PropertyType::SessionExpiryInterval => { if properties.session_expiry_interval.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index c54689c..63a40b0 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -4,6 +4,14 @@ use thiserror::Error; use super::{PacketType, PropertyType}; +#[derive(Error, Debug)] +pub enum ReadError{ + #[error("{0}")] + DeserializeError(#[from] DeserializeError), + #[error("{0}")] + IoError(#[from] std::io::Error), +} + #[derive(Error, Clone, Debug)] pub enum DeserializeError { #[error("Malformed packet: {0}")] @@ -22,7 +30,7 @@ pub enum DeserializeError { UnknownProtocolVersion, #[error("There is insufficient for {0} data ({1}) to take {2} bytes")] - InsufficientData(String, usize, usize), + InsufficientData(&'static str, usize, usize), #[error("There is insufficient to read the protocol version.")] InsufficientDataForProtocolVersion, diff --git a/mqrstt/src/packets/macros.rs b/mqrstt/src/packets/macros.rs new file mode 100644 index 0000000..023b4f5 --- /dev/null +++ b/mqrstt/src/packets/macros.rs @@ -0,0 +1,28 @@ +macro_rules! MqttAsyncRead { + ( $name:ident, $id:expr, { $($fname:ident : $ftype:ty),* } ) => { + #[derive(Codec, Debug, Eq, PartialEq, Clone)] + pub struct $name{ + $( + pub $fname: $ftype, + )* + } + impl crate::packets::mqtt_traits::MqttAsyncRead for $name where T: tokio::io::AsyncReadExt{ + fn name(&self) -> &'static str{ + stringify!($name) + } + + } + }; + ($name:ident, $id:expr) => { + #[derive(Codec, Debug, Eq, PartialEq, Clone)] + pub struct $name{} + impl RequestTrait for $name{ + fn name(&self) -> &'static str{ + stringify!($name) + } + fn get_id(&self) -> u32 { + return $id; + } + } + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 8675d1b..5f7c6b2 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -2,6 +2,7 @@ pub mod error; pub mod mqtt_traits; pub mod reason_codes; +mod macros; mod auth; mod connack; mod connect; @@ -20,6 +21,8 @@ pub use auth::*; pub use connack::*; pub use connect::*; pub use disconnect::*; +use error::ReadError; +use mqtt_traits::MqttAsyncRead; pub use puback::*; pub use pubcomp::*; pub use publish::*; @@ -65,6 +68,18 @@ impl MqttRead for ProtocolVersion { } } +impl MqttAsyncRead for ProtocolVersion where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + match buf.read_u8().await { + Ok(5) => Ok(ProtocolVersion::V5), + Ok(4) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(3) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(_) => Err(ReadError::DeserializeError(DeserializeError::UnknownProtocolVersion)), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + /// Quality of service #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum QoS { @@ -95,7 +110,7 @@ impl MqttRead for QoS { #[inline] fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("QoS".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -107,6 +122,18 @@ impl MqttRead for QoS { } } +impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + match buf.read_u8().await { + Ok(0) => Ok(QoS::AtMostOnce), + Ok(1) => Ok(QoS::AtLeastOnce), + Ok(2) => Ok(QoS::ExactlyOnce), + Ok(q) => Err(ReadError::DeserializeError(DeserializeError::UnknownQoS(q))), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + impl MqttWrite for QoS { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -132,6 +159,16 @@ impl MqttRead for Box { } } +impl MqttAsyncRead for Box where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + let content = Bytes::async_read(buf).await?; + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s.into()), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + impl MqttWrite for Box { #[inline(always)] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -174,6 +211,17 @@ impl MqttRead for String { } } +impl MqttAsyncRead for String where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + let content = Bytes::async_read(buf).await?; + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + + impl MqttWrite for String { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -200,13 +248,23 @@ impl MqttRead for Bytes { let len = buf.get_u16() as usize; if len > buf.len() { - return Err(DeserializeError::InsufficientData("Bytes".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } Ok(buf.split_to(len)) } } +impl MqttAsyncRead for Bytes where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + let size = buf.read_u16().await?; + let mut data = BytesMut::with_capacity(size as usize); + buf.read_exact(&mut data).await?; + Ok(data.into()) + } +} + + impl MqttWrite for Bytes { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -227,7 +285,7 @@ impl WireLength for Bytes { impl MqttRead for bool { fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("bool".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -238,6 +296,16 @@ impl MqttRead for bool { } } +impl MqttAsyncRead for bool where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + match buf.read_u8().await? { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)), + } + } +} + impl MqttWrite for bool { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -255,22 +323,34 @@ impl MqttRead for u8 { #[inline] fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("u8".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } Ok(buf.get_u8()) } } +impl MqttAsyncRead for u8 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + Ok(buf.read_u8().await?) + } +} + impl MqttRead for u16 { #[inline] fn read(buf: &mut Bytes) -> Result { if buf.len() < 2 { - return Err(DeserializeError::InsufficientData("u16".to_string(), buf.len(), 2)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); } Ok(buf.get_u16()) } } +impl MqttAsyncRead for u16 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + Ok(buf.read_u16().await?) + } +} + impl MqttWrite for u16 { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -283,11 +363,16 @@ impl MqttRead for u32 { #[inline] fn read(buf: &mut Bytes) -> Result { if buf.len() < 4 { - return Err(DeserializeError::InsufficientData("u32".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } Ok(buf.get_u32()) } } +impl MqttAsyncRead for u32 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + Ok(buf.read_u32().await?) + } +} impl MqttWrite for u32 { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -399,13 +484,11 @@ pub enum PropertyType { SharedSubscriptionAvailable = 42, } -impl MqttRead for PropertyType { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PropertyType".to_string(), 0, 1)); - } +impl TryFrom for PropertyType { + type Error = DeserializeError; - match buf.get_u8() { + fn try_from(value: u8) -> Result { + match value { 1 => Ok(Self::PayloadFormatIndicator), 2 => Ok(Self::MessageExpiryInterval), 3 => Ok(Self::ContentType), @@ -438,109 +521,72 @@ impl MqttRead for PropertyType { } } -impl MqttWrite for PropertyType { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - let val = match self { - Self::PayloadFormatIndicator => 1, - Self::MessageExpiryInterval => 2, - Self::ContentType => 3, - Self::ResponseTopic => 8, - Self::CorrelationData => 9, - Self::SubscriptionIdentifier => 11, - Self::SessionExpiryInterval => 17, - Self::AssignedClientIdentifier => 18, - Self::ServerKeepAlive => 19, - Self::AuthenticationMethod => 21, - Self::AuthenticationData => 22, - Self::RequestProblemInformation => 23, - Self::WillDelayInterval => 24, - Self::RequestResponseInformation => 25, - Self::ResponseInformation => 26, - Self::ServerReference => 28, - Self::ReasonString => 31, - Self::ReceiveMaximum => 33, - Self::TopicAliasMaximum => 34, - Self::TopicAlias => 35, - Self::MaximumQos => 36, - Self::RetainAvailable => 37, - Self::UserProperty => 38, - Self::MaximumPacketSize => 39, - Self::WildcardSubscriptionAvailable => 40, - Self::SubscriptionIdentifierAvailable => 41, - Self::SharedSubscriptionAvailable => 42, - }; +impl From<&PropertyType> for u8 { + fn from(value: &PropertyType) -> Self { + match value { + PropertyType::PayloadFormatIndicator => 1, + PropertyType::MessageExpiryInterval => 2, + PropertyType::ContentType => 3, + PropertyType::ResponseTopic => 8, + PropertyType::CorrelationData => 9, + PropertyType::SubscriptionIdentifier => 11, + PropertyType::SessionExpiryInterval => 17, + PropertyType::AssignedClientIdentifier => 18, + PropertyType::ServerKeepAlive => 19, + PropertyType::AuthenticationMethod => 21, + PropertyType::AuthenticationData => 22, + PropertyType::RequestProblemInformation => 23, + PropertyType::WillDelayInterval => 24, + PropertyType::RequestResponseInformation => 25, + PropertyType::ResponseInformation => 26, + PropertyType::ServerReference => 28, + PropertyType::ReasonString => 31, + PropertyType::ReceiveMaximum => 33, + PropertyType::TopicAliasMaximum => 34, + PropertyType::TopicAlias => 35, + PropertyType::MaximumQos => 36, + PropertyType::RetainAvailable => 37, + PropertyType::UserProperty => 38, + PropertyType::MaximumPacketSize => 39, + PropertyType::WildcardSubscriptionAvailable => 40, + PropertyType::SubscriptionIdentifierAvailable => 41, + PropertyType::SharedSubscriptionAvailable => 42, + } + } +} - buf.put_u8(val); - Ok(()) +impl From for u8 { + fn from(value: PropertyType) -> Self { + value.into() } } -impl PropertyType { - pub fn from_u8(value: u8) -> Result { - match value { - 1 => Ok(Self::PayloadFormatIndicator), - 2 => Ok(Self::MessageExpiryInterval), - 3 => Ok(Self::ContentType), - 8 => Ok(Self::ResponseTopic), - 9 => Ok(Self::CorrelationData), - 11 => Ok(Self::SubscriptionIdentifier), - 17 => Ok(Self::SessionExpiryInterval), - 18 => Ok(Self::AssignedClientIdentifier), - 19 => Ok(Self::ServerKeepAlive), - 21 => Ok(Self::AuthenticationMethod), - 22 => Ok(Self::AuthenticationData), - 23 => Ok(Self::RequestProblemInformation), - 24 => Ok(Self::WillDelayInterval), - 25 => Ok(Self::RequestResponseInformation), - 26 => Ok(Self::ResponseInformation), - 28 => Ok(Self::ServerReference), - 31 => Ok(Self::ReasonString), - 33 => Ok(Self::ReceiveMaximum), - 34 => Ok(Self::TopicAliasMaximum), - 35 => Ok(Self::TopicAlias), - 36 => Ok(Self::MaximumQos), - 37 => Ok(Self::RetainAvailable), - 38 => Ok(Self::UserProperty), - 39 => Ok(Self::MaximumPacketSize), - 40 => Ok(Self::WildcardSubscriptionAvailable), - 41 => Ok(Self::SubscriptionIdentifierAvailable), - 42 => Ok(Self::SharedSubscriptionAvailable), - _ => Err("Unkown property type".to_string()), +impl MqttRead for PropertyType { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } + + buf.get_u8().try_into() } - pub fn to_u8(self) -> u8 { - match self { - Self::PayloadFormatIndicator => 1, - Self::MessageExpiryInterval => 2, - Self::ContentType => 3, - Self::ResponseTopic => 8, - Self::CorrelationData => 9, - Self::SubscriptionIdentifier => 11, - Self::SessionExpiryInterval => 17, - Self::AssignedClientIdentifier => 18, - Self::ServerKeepAlive => 19, - Self::AuthenticationMethod => 21, - Self::AuthenticationData => 22, - Self::RequestProblemInformation => 23, - Self::WillDelayInterval => 24, - Self::RequestResponseInformation => 25, - Self::ResponseInformation => 26, - Self::ServerReference => 28, - Self::ReasonString => 31, - Self::ReceiveMaximum => 33, - Self::TopicAliasMaximum => 34, - Self::TopicAlias => 35, - Self::MaximumQos => 36, - Self::RetainAvailable => 37, - Self::UserProperty => 38, - Self::MaximumPacketSize => 39, - Self::WildcardSubscriptionAvailable => 40, - Self::SubscriptionIdentifierAvailable => 41, - Self::SharedSubscriptionAvailable => 42, +} + +impl MqttAsyncRead for PropertyType where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result { + match buf.read_u8().await { + Ok(t) => Ok(t.try_into()?), + Err(e) => Err(ReadError::IoError(e)), } } } +impl MqttWrite for PropertyType { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u8(self.into()); + Ok(()) + } +} + // ==================== Packets ==================== #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/mqrstt/src/packets/mqtt_traits.rs b/mqrstt/src/packets/mqtt_traits.rs index 8f6122f..35fb6fd 100644 --- a/mqrstt/src/packets/mqtt_traits.rs +++ b/mqrstt/src/packets/mqtt_traits.rs @@ -1,6 +1,6 @@ use bytes::{Bytes, BytesMut}; -use super::error::{DeserializeError, SerializeError}; +use super::error::{DeserializeError, ReadError, SerializeError}; pub trait VariableHeaderRead: Sized { fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; @@ -17,6 +17,10 @@ pub trait WireLength { pub trait MqttRead: Sized { fn read(buf: &mut Bytes) -> Result; } +pub trait MqttAsyncRead: Sized where T: tokio::io::AsyncReadExt { + async fn async_read(buf: &mut T) -> Result; +} + pub trait MqttWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; diff --git a/mqrstt/src/packets/puback.rs b/mqrstt/src/packets/puback.rs index 4b0970c..bfba6b8 100644 --- a/mqrstt/src/packets/puback.rs +++ b/mqrstt/src/packets/puback.rs @@ -27,7 +27,7 @@ impl VariableHeaderRead for PubAck { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubAck".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -94,7 +94,7 @@ impl MqttRead for PubAckProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubAckProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties = PubAckProperties::default(); diff --git a/mqrstt/src/packets/pubcomp.rs b/mqrstt/src/packets/pubcomp.rs index b814da2..69ba679 100644 --- a/mqrstt/src/packets/pubcomp.rs +++ b/mqrstt/src/packets/pubcomp.rs @@ -37,7 +37,7 @@ impl VariableHeaderRead for PubComp { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubComp".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -100,7 +100,7 @@ impl MqttRead for PubCompProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubCompProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties = PubCompProperties::default(); diff --git a/mqrstt/src/packets/publish.rs b/mqrstt/src/packets/publish.rs index 6f3baeb..d744a9b 100644 --- a/mqrstt/src/packets/publish.rs +++ b/mqrstt/src/packets/publish.rs @@ -164,7 +164,7 @@ impl MqttRead for PublishProperties { if len == 0 { return Ok(Self::default()); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("PublishProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut property_data = buf.split_to(len); @@ -172,7 +172,7 @@ impl MqttRead for PublishProperties { let mut properties = Self::default(); loop { - match PropertyType::from_u8(u8::read(&mut property_data)?)? { + match PropertyType::try_from(u8::read(&mut property_data)?)? { PropertyType::PayloadFormatIndicator => { if properties.payload_format_indicator.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); @@ -229,36 +229,36 @@ impl MqttWrite for PublishProperties { write_variable_integer(buf, self.wire_len())?; if let Some(payload_format_indicator) = self.payload_format_indicator { - buf.put_u8(PropertyType::PayloadFormatIndicator.to_u8()); + buf.put_u8(PropertyType::PayloadFormatIndicator.into()); buf.put_u8(payload_format_indicator); } if let Some(message_expiry_interval) = self.message_expiry_interval { - buf.put_u8(PropertyType::MessageExpiryInterval.to_u8()); + buf.put_u8(PropertyType::MessageExpiryInterval.into()); buf.put_u32(message_expiry_interval); } if let Some(topic_alias) = self.topic_alias { - buf.put_u8(PropertyType::TopicAlias.to_u8()); + buf.put_u8(PropertyType::TopicAlias.into()); buf.put_u16(topic_alias); } if let Some(response_topic) = &self.response_topic { - buf.put_u8(PropertyType::ResponseTopic.to_u8()); + buf.put_u8(PropertyType::ResponseTopic.into()); response_topic.as_ref().write(buf)?; } if let Some(correlation_data) = &self.correlation_data { - buf.put_u8(PropertyType::CorrelationData.to_u8()); + buf.put_u8(PropertyType::CorrelationData.into()); correlation_data.write(buf)?; } for sub_id in &self.subscription_identifier { - buf.put_u8(PropertyType::SubscriptionIdentifier.to_u8()); + buf.put_u8(PropertyType::SubscriptionIdentifier.into()); write_variable_integer(buf, *sub_id)?; } for (key, val) in &self.user_properties { - buf.put_u8(PropertyType::UserProperty.to_u8()); + buf.put_u8(PropertyType::UserProperty.into()); key.write(buf)?; val.write(buf)?; } if let Some(content_type) = &self.content_type { - buf.put_u8(PropertyType::ContentType.to_u8()); + buf.put_u8(PropertyType::ContentType.into()); content_type.write(buf)?; } diff --git a/mqrstt/src/packets/pubrec.rs b/mqrstt/src/packets/pubrec.rs index 15289c8..2351b4e 100644 --- a/mqrstt/src/packets/pubrec.rs +++ b/mqrstt/src/packets/pubrec.rs @@ -36,7 +36,7 @@ impl VariableHeaderRead for PubRec { } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData("PubRec".to_string(), buf.len(), 4)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); } let packet_identifier = u16::read(&mut buf)?; @@ -99,7 +99,7 @@ impl MqttRead for PubRecProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubRecProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties = PubRecProperties::default(); diff --git a/mqrstt/src/packets/pubrel.rs b/mqrstt/src/packets/pubrel.rs index 79039ac..aedbc9b 100644 --- a/mqrstt/src/packets/pubrel.rs +++ b/mqrstt/src/packets/pubrel.rs @@ -98,7 +98,7 @@ impl MqttRead for PubRelProperties { return Ok(Self::default()); } if buf.len() < len { - return Err(DeserializeError::InsufficientData("PubRelProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties = PubRelProperties::default(); diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs index a940562..604189e 100644 --- a/mqrstt/src/packets/reason_codes.rs +++ b/mqrstt/src/packets/reason_codes.rs @@ -36,7 +36,7 @@ pub enum ConnAckReasonCode { impl MqttRead for ConnAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("ConAckReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -110,7 +110,7 @@ pub enum AuthReasonCode { impl MqttRead for AuthReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("AuthReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -173,7 +173,7 @@ pub enum DisconnectReasonCode { impl MqttRead for DisconnectReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("DisconnectReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -267,7 +267,7 @@ pub enum PubAckReasonCode { impl MqttRead for PubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubAckReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -314,7 +314,7 @@ pub enum PubCompReasonCode { impl MqttRead for PubCompReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubCompReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -353,7 +353,7 @@ pub enum PubRecReasonCode { impl MqttRead for PubRecReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubRecReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -399,7 +399,7 @@ pub enum PubRelReasonCode { impl MqttRead for PubRelReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("PubRelReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -441,7 +441,7 @@ pub enum SubAckReasonCode { impl MqttRead for SubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("SubAckReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { @@ -498,7 +498,7 @@ pub enum UnsubAckReasonCode { impl MqttRead for UnsubAckReasonCode { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("UnsubAckReasonCode".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } match buf.get_u8() { diff --git a/mqrstt/src/packets/suback.rs b/mqrstt/src/packets/suback.rs index 3f8caa2..3ea0691 100644 --- a/mqrstt/src/packets/suback.rs +++ b/mqrstt/src/packets/suback.rs @@ -74,7 +74,7 @@ impl MqttRead for SubAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("SubAckProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties_data = buf.split_to(len); diff --git a/mqrstt/src/packets/subscribe.rs b/mqrstt/src/packets/subscribe.rs index 80d994f..2c44d03 100644 --- a/mqrstt/src/packets/subscribe.rs +++ b/mqrstt/src/packets/subscribe.rs @@ -108,7 +108,7 @@ impl MqttRead for SubscribeProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("SubscribeProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties_data = buf.split_to(len); @@ -189,7 +189,7 @@ impl Default for SubscriptionOptions { impl MqttRead for SubscriptionOptions { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData("SubscriptionOptions".to_string(), 0, 1)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } let byte = buf.get_u8(); diff --git a/mqrstt/src/packets/unsuback.rs b/mqrstt/src/packets/unsuback.rs index da9e447..5d53378 100644 --- a/mqrstt/src/packets/unsuback.rs +++ b/mqrstt/src/packets/unsuback.rs @@ -63,7 +63,7 @@ impl MqttRead for UnsubAckProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("UnsubAckProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties_data = buf.split_to(len); diff --git a/mqrstt/src/packets/unsubscribe.rs b/mqrstt/src/packets/unsubscribe.rs index 19e08f4..29f783a 100644 --- a/mqrstt/src/packets/unsubscribe.rs +++ b/mqrstt/src/packets/unsubscribe.rs @@ -98,7 +98,7 @@ impl MqttRead for UnsubscribeProperties { if len == 0 { return Ok(properties); } else if buf.len() < len { - return Err(DeserializeError::InsufficientData("UnsubscribeProperties".to_string(), buf.len(), len)); + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } let mut properties_data = buf.split_to(len); From 30d435fd16b49cd8e416bcd70e7a2e8e436057c0 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Sun, 17 Nov 2024 01:37:38 +0100 Subject: [PATCH 02/45] WIP: macros for reason codes and properties for wire length and async read + write --- mqrstt/src/packets/auth.rs | 6 +- mqrstt/src/packets/connack.rs | 367 +++++++---- mqrstt/src/packets/connect.rs | 8 +- mqrstt/src/packets/disconnect.rs | 6 +- mqrstt/src/packets/error.rs | 5 +- mqrstt/src/packets/macros.rs | 28 - mqrstt/src/packets/macros/mod.rs | 6 + .../src/packets/macros/properties_macros.rs | 609 ++++++++++++++++++ .../src/packets/macros/reason_code_macros.rs | 527 +++++++++++++++ mqrstt/src/packets/mod.rs | 137 ++-- mqrstt/src/packets/mqtt_traits.rs | 18 +- mqrstt/src/packets/puback.rs | 10 +- mqrstt/src/packets/pubcomp.rs | 10 +- mqrstt/src/packets/publish.rs | 41 +- mqrstt/src/packets/pubrec.rs | 10 +- mqrstt/src/packets/pubrel.rs | 262 +++++++- mqrstt/src/packets/reason_codes.rs | 496 ++++++++------ mqrstt/src/packets/suback.rs | 8 +- mqrstt/src/packets/subscribe.rs | 8 +- mqrstt/src/packets/unsuback.rs | 8 +- mqrstt/src/packets/unsubscribe.rs | 8 +- mqrstt/src/util/constants.rs | 6 +- 22 files changed, 2095 insertions(+), 489 deletions(-) delete mode 100644 mqrstt/src/packets/macros.rs create mode 100644 mqrstt/src/packets/macros/mod.rs create mode 100644 mqrstt/src/packets/macros/properties_macros.rs create mode 100644 mqrstt/src/packets/macros/reason_code_macros.rs diff --git a/mqrstt/src/packets/auth.rs b/mqrstt/src/packets/auth.rs index dc6ba68..0da56c0 100644 --- a/mqrstt/src/packets/auth.rs +++ b/mqrstt/src/packets/auth.rs @@ -2,7 +2,7 @@ use bytes::Bytes; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::AuthReasonCode, variable_integer_len, write_variable_integer, PacketType, PropertyType, @@ -14,7 +14,7 @@ pub struct Auth { pub properties: AuthProperties, } -impl VariableHeaderRead for Auth { +impl PacketRead for Auth { fn read(_: u8, _: usize, mut buf: Bytes) -> Result { let reason_code = AuthReasonCode::read(&mut buf)?; let properties = AuthProperties::read(&mut buf)?; @@ -23,7 +23,7 @@ impl VariableHeaderRead for Auth { } } -impl VariableHeaderWrite for Auth { +impl PacketWrite for Auth { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { self.reason_code.write(buf)?; self.properties.write(buf)?; diff --git a/mqrstt/src/packets/connack.rs b/mqrstt/src/packets/connack.rs index bf7c7e9..ef6e92d 100644 --- a/mqrstt/src/packets/connack.rs +++ b/mqrstt/src/packets/connack.rs @@ -1,11 +1,11 @@ use super::{ error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::ConnAckReasonCode, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, }; -use bytes::{Buf, BufMut, Bytes}; +use bytes::{Buf, BufMut}; #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ConnAck { @@ -20,7 +20,7 @@ pub struct ConnAck { pub connack_properties: ConnAckProperties, } -impl VariableHeaderRead for ConnAck { +impl PacketRead for ConnAck { fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { if header_len > buf.len() { return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), header_len)); @@ -38,7 +38,27 @@ impl VariableHeaderRead for ConnAck { } } -impl VariableHeaderWrite for ConnAck { +impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; + let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; + let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; + + Ok(( + Self { + connack_flags, + reason_code, + connack_properties, + }, + read_bytes + reason_code_read_bytes + connack_properties_read_bytes + )) + + } + } +} + +impl PacketWrite for ConnAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { self.connack_flags.write(buf)?; self.reason_code.write(buf)?; @@ -56,78 +76,98 @@ impl WireLength for ConnAck { } } -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ConnAckProperties { - /// 3.2.2.3.2 Session Expiry Interval - /// 17 (0x11) Byte Identifier of the Session Expiry Interval - pub session_expiry_interval: Option, - - /// 3.2.2.3.3 Receive Maximum - /// 33 (0x21) Byte, Identifier of the Receive Maximum - pub receive_maximum: Option, - - /// 3.2.2.3.4 Maximum QoS - /// 36 (0x24) Byte, Identifier of the Maximum QoS. - pub maximum_qos: Option, - - /// 3.2.2.3.5 Retain Available - /// 37 (0x25) Byte, Identifier of Retain Available. - pub retain_available: Option, - - /// 3.2.2.3.6 Maximum Packet Size - /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. - pub maximum_packet_size: Option, - - /// 3.2.2.3.7 Assigned Client Identifier - /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. - pub assigned_client_id: Option>, - - /// 3.2.2.3.8 Topic Alias Maximum - /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. - pub topic_alias_maximum: Option, - - /// 3.2.2.3.9 Reason String - /// 31 (0x1F) Byte Identifier of the Reason String. - pub reason_string: Option>, - - /// 3.2.2.3.10 User Property - /// 38 (0x26) Byte, Identifier of User Property. - pub user_properties: Vec<(Box, Box)>, - - /// 3.2.2.3.11 Wildcard Subscription Available - /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. - pub wildcards_available: Option, - - /// 3.2.2.3.12 Subscription Identifiers Available - /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. - pub subscription_ids_available: Option, - - /// 3.2.2.3.13 Shared Subscription Available - /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. - pub shared_subscription_available: Option, - - /// 3.2.2.3.14 Server Keep Alive - /// 19 (0x13) Byte, Identifier of the Server Keep Alive - pub server_keep_alive: Option, - - /// 3.2.2.3.15 Response Information - /// 26 (0x1A) Byte, Identifier of the Response Information. - pub response_info: Option>, - - /// 3.2.2.3.16 Server Reference - /// 28 (0x1C) Byte, Identifier of the Server Reference - pub server_reference: Option>, - - /// 3.2.2.3.17 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method - pub authentication_method: Option>, - - /// 3.2.2.3.18 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - // There is a small inconsistency here with authentication_data in the connect packet. - // This is Option while that type uses just Bytes. - pub authentication_data: Option, -} +super::macros::define_properties!(ConnAckProperties, + SessionExpiryInterval, + ReceiveMaximum, + MaximumQos, + RetainAvailable, + MaximumPacketSize, + AssignedClientIdentifier, + TopicAliasMaximum, + ReasonString, + UserProperty, + WildcardSubscriptionAvailable, + SubscriptionIdentifierAvailable, + SharedSubscriptionAvailable, + ServerKeepAlive, + ResponseInformation, + ServerReference, + AuthenticationMethod, + AuthenticationData +); + +// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// pub struct ConnAckProperties { +// /// 3.2.2.3.2 Session Expiry Interval +// /// 17 (0x11) Byte Identifier of the Session Expiry Interval +// pub session_expiry_interval: Option, + +// /// 3.2.2.3.3 Receive Maximum +// /// 33 (0x21) Byte, Identifier of the Receive Maximum +// pub receive_maximum: Option, + +// /// 3.2.2.3.4 Maximum QoS +// /// 36 (0x24) Byte, Identifier of the Maximum QoS. +// pub maximum_qos: Option, + +// /// 3.2.2.3.5 Retain Available +// /// 37 (0x25) Byte, Identifier of Retain Available. +// pub retain_available: Option, + +// /// 3.2.2.3.6 Maximum Packet Size +// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. +// pub maximum_packet_size: Option, + +// /// 3.2.2.3.7 Assigned Client Identifier +// /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. +// pub assigned_client_id: Option>, + +// /// 3.2.2.3.8 Topic Alias Maximum +// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. +// pub topic_alias_maximum: Option, + +// /// 3.2.2.3.9 Reason String +// /// 31 (0x1F) Byte Identifier of the Reason String. +// pub reason_string: Option>, + +// /// 3.2.2.3.10 User Property +// /// 38 (0x26) Byte, Identifier of User Property. +// pub user_properties: Vec<(Box, Box)>, + +// /// 3.2.2.3.11 Wildcard Subscription Available +// /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. +// pub wildcards_available: Option, + +// /// 3.2.2.3.12 Subscription Identifiers Available +// /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. +// pub subscription_ids_available: Option, + +// /// 3.2.2.3.13 Shared Subscription Available +// /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. +// pub shared_subscription_available: Option, + +// /// 3.2.2.3.14 Server Keep Alive +// /// 19 (0x13) Byte, Identifier of the Server Keep Alive +// pub server_keep_alive: Option, + +// /// 3.2.2.3.15 Response Information +// /// 26 (0x1A) Byte, Identifier of the Response Information. +// pub response_info: Option>, + +// /// 3.2.2.3.16 Server Reference +// /// 28 (0x1C) Byte, Identifier of the Server Reference +// pub server_reference: Option>, + +// /// 3.2.2.3.17 Authentication Method +// /// 21 (0x15) Byte, Identifier of the Authentication Method +// pub authentication_method: Option>, + +// /// 3.2.2.3.18 Authentication Data +// /// 22 (0x16) Byte, Identifier of the Authentication Data +// // There is a small inconsistency here with authentication_data in the connect packet. +// // This is Option while that type uses just Bytes. +// pub authentication_data: Option>, +// } impl MqttRead for ConnAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -240,7 +280,7 @@ impl MqttRead for ConnAckProperties { if properties.authentication_data.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); } - properties.authentication_data = Some(Bytes::read(&mut property_data)?); + properties.authentication_data = Some(Vec::::read(&mut property_data)?); } e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::ConnAck)), @@ -356,73 +396,84 @@ impl MqttWrite for ConnAckProperties { } } -impl WireLength for ConnAckProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.session_expiry_interval.is_some() { - len += 1 + 4; - } - if self.receive_maximum.is_some() { - len += 1 + 2; - } - if self.maximum_qos.is_some() { - len += 1 + 1; - } - if self.retain_available.is_some() { - len += 1 + 1; - } - if self.maximum_packet_size.is_some() { - len += 1 + 4; - } - if let Some(client_id) = &self.assigned_client_id { - len += 1 + client_id.wire_len(); - } - if self.topic_alias_maximum.is_some() { - len += 1 + 2; - } - if let Some(reason_string) = &self.reason_string { - len += 1 + reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += 1; - len += key.wire_len(); - len += value.wire_len(); - } - if self.wildcards_available.is_some() { - len += 1 + 1; - } - if self.subscription_ids_available.is_some() { - len += 1 + 1; - } - if self.shared_subscription_available.is_some() { - len += 1 + 1; - } - if self.server_keep_alive.is_some() { - len += 1 + 2; - } - if let Some(response_info) = &self.response_info { - len += 1 + response_info.wire_len(); - } - if let Some(server_reference) = &self.server_reference { - len += 1 + server_reference.wire_len(); - } - if let Some(authentication_method) = &self.authentication_method { - len += 1 + authentication_method.wire_len(); - } - if self.authentication_data.is_some() && self.authentication_method.is_some() { - len += 1 + self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); - } - - len - } -} +// impl WireLength for ConnAckProperties { +// fn wire_len(&self) -> usize { +// let mut len: usize = 0; + +// if self.session_expiry_interval.is_some() { +// len += 1 + 4; +// } +// if self.receive_maximum.is_some() { +// len += 1 + 2; +// } +// if self.maximum_qos.is_some() { +// len += 1 + 1; +// } +// if self.retain_available.is_some() { +// len += 1 + 1; +// } +// if self.maximum_packet_size.is_some() { +// len += 1 + 4; +// } +// if let Some(client_id) = &self.assigned_client_id { +// len += 1 + client_id.wire_len(); +// } +// if self.topic_alias_maximum.is_some() { +// len += 1 + 2; +// } +// if let Some(reason_string) = &self.reason_string { +// len += 1 + reason_string.wire_len(); +// } +// for (key, value) in &self.user_properties { +// len += 1; +// len += key.wire_len(); +// len += value.wire_len(); +// } +// if self.wildcards_available.is_some() { +// len += 1 + 1; +// } +// if self.subscription_ids_available.is_some() { +// len += 1 + 1; +// } +// if self.shared_subscription_available.is_some() { +// len += 1 + 1; +// } +// if self.server_keep_alive.is_some() { +// len += 1 + 2; +// } +// if let Some(response_info) = &self.response_info { +// len += 1 + response_info.wire_len(); +// } +// if let Some(server_reference) = &self.server_reference { +// len += 1 + server_reference.wire_len(); +// } +// if let Some(authentication_method) = &self.authentication_method { +// len += 1 + authentication_method.wire_len(); +// } +// if self.authentication_data.is_some() && self.authentication_method.is_some() { +// len += 1 + self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); +// } + +// len +// } +// } #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct ConnAckFlags { pub session_present: bool, } +impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(stream: &mut S) -> impl std::future::Future> { + async move { + let byte = stream.read_u8().await?; + Ok((Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, 1)) + } + } +} + impl MqttRead for ConnAckFlags { fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { @@ -451,11 +502,46 @@ mod tests { use crate::packets::{ connack::{ConnAck, ConnAckProperties}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, reason_codes::ConnAckReasonCode, Packet, }; + #[test] + fn test_wire_len() { + let mut buf = bytes::BytesMut::new(); + + let connack_properties = ConnAckProperties { + session_expiry_interval: Some(60), // Session expiry interval in seconds + receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently + maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server + retain_available: Some(true), // Whether the server supports retained messages + maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept + assigned_client_id: Some(Box::from("client-12345")), // Client identifier assigned by the server + topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server + reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment + user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair + wildcards_available: Some(true), // Whether wildcard subscriptions are available + subscription_ids_available: Some(true), // Whether subscription identifiers are available + shared_subscription_available: Some(true), // Whether shared subscriptions are available + server_keep_alive: Some(120), // Server keep alive time in seconds + response_info: Some(Box::from("Response info")), // Response information + server_reference: Some(Box::from("server-reference")), // Server reference + authentication_method: Some(Box::from("auth-method")), // Authentication method + authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data + }; + + let len = connack_properties.wire_len(); + // determine length of variable integer + let len_of_wire_len = crate::packets::write_variable_integer(&mut buf, len).unwrap(); + // clear buffer before writing actual properties + buf.clear(); + connack_properties.write(&mut buf).unwrap(); + + assert_eq!(len + len_of_wire_len, buf.len()); + + } + #[test] fn read_write_connack_packet() { let c = ConnAck { ..Default::default() }; @@ -516,6 +602,9 @@ mod tests { let mut buf = bytes::BytesMut::new(); + let variable_length = c1.wire_len(); + assert_eq!(variable_length, 56); + c1.write(&mut buf).unwrap(); let _buf_clone = buf.to_vec(); diff --git a/mqrstt/src/packets/connect.rs b/mqrstt/src/packets/connect.rs index f1cf0ed..c13d1aa 100644 --- a/mqrstt/src/packets/connect.rs +++ b/mqrstt/src/packets/connect.rs @@ -2,7 +2,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use super::{ error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite}, read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, ProtocolVersion, QoS, WireLength, }; @@ -95,7 +95,7 @@ impl Default for Connect { } } -impl VariableHeaderRead for Connect { +impl PacketRead for Connect { fn read(_: u8, _: usize, mut buf: Bytes) -> Result { if String::read(&mut buf)? != "MQTT" { return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); @@ -136,7 +136,7 @@ impl VariableHeaderRead for Connect { } } -impl VariableHeaderWrite for Connect { +impl PacketWrite for Connect { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { "MQTT".write(buf)?; @@ -689,7 +689,7 @@ impl WireLength for LastWillProperties { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttWrite, VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{MqttWrite, PacketRead, PacketWrite}, QoS, }; diff --git a/mqrstt/src/packets/disconnect.rs b/mqrstt/src/packets/disconnect.rs index 19fd103..e5c70e9 100644 --- a/mqrstt/src/packets/disconnect.rs +++ b/mqrstt/src/packets/disconnect.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::DisconnectReasonCode, variable_integer_len, write_variable_integer, PacketType, PropertyType, @@ -14,7 +14,7 @@ pub struct Disconnect { pub properties: DisconnectProperties, } -impl VariableHeaderRead for Disconnect { +impl PacketRead for Disconnect { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let reason_code; let properties; @@ -29,7 +29,7 @@ impl VariableHeaderRead for Disconnect { Ok(Self { reason_code, properties }) } } -impl VariableHeaderWrite for Disconnect { +impl PacketWrite for Disconnect { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { self.reason_code.write(buf)?; diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index 63a40b0..cb423ed 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -31,9 +31,12 @@ pub enum DeserializeError { #[error("There is insufficient for {0} data ({1}) to take {2} bytes")] InsufficientData(&'static str, usize, usize), - + #[error("There is insufficient to read the protocol version.")] InsufficientDataForProtocolVersion, + + #[error("Read more data for the packet than indicated length")] + ReadTooMuchData(&'static str, usize, usize), #[error("Reason code {0} is not allowed for packet type {1:?}")] UnexpectedReasonCode(u8, PacketType), diff --git a/mqrstt/src/packets/macros.rs b/mqrstt/src/packets/macros.rs deleted file mode 100644 index 023b4f5..0000000 --- a/mqrstt/src/packets/macros.rs +++ /dev/null @@ -1,28 +0,0 @@ -macro_rules! MqttAsyncRead { - ( $name:ident, $id:expr, { $($fname:ident : $ftype:ty),* } ) => { - #[derive(Codec, Debug, Eq, PartialEq, Clone)] - pub struct $name{ - $( - pub $fname: $ftype, - )* - } - impl crate::packets::mqtt_traits::MqttAsyncRead for $name where T: tokio::io::AsyncReadExt{ - fn name(&self) -> &'static str{ - stringify!($name) - } - - } - }; - ($name:ident, $id:expr) => { - #[derive(Codec, Debug, Eq, PartialEq, Clone)] - pub struct $name{} - impl RequestTrait for $name{ - fn name(&self) -> &'static str{ - stringify!($name) - } - fn get_id(&self) -> u32 { - return $id; - } - } - } -} \ No newline at end of file diff --git a/mqrstt/src/packets/macros/mod.rs b/mqrstt/src/packets/macros/mod.rs new file mode 100644 index 0000000..bc734e6 --- /dev/null +++ b/mqrstt/src/packets/macros/mod.rs @@ -0,0 +1,6 @@ +mod properties_macros; +mod reason_code_macros; + + +pub(crate) use reason_code_macros::*; +pub(crate) use properties_macros::*; diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs new file mode 100644 index 0000000..2071051 --- /dev/null +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -0,0 +1,609 @@ +macro_rules! define_properties { + ($name:ident, $($prop_variant:ident),*) => { + $crate::packets::macros::properties_struct!(@ $name { $($prop_variant,)* } -> ()); + + impl $crate::packets::mqtt_traits::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let (len, length_variable_integer) = $crate::packets::read_async_variable_integer(stream).await?; + if len == 0 { + return Ok((Self::default(), length_variable_integer)); + } + + let mut properties = $name::default(); + + let mut read_property_bytes = 0; + loop { + let (prop, read_bytes) = PropertyType::async_read(stream).await?; + read_property_bytes += read_bytes; + match prop { + $( + PropertyType::$prop_variant => $crate::packets::macros::properties_read_matches!(stream, properties, read_property_bytes, PropertyType::$prop_variant), + )* + e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), + } + if read_property_bytes == len { + break; + } + } + + Ok((properties, length_variable_integer + read_property_bytes)) + } + } + + impl $crate::packets::mqtt_traits::WireLength for $name { + fn wire_len(&self) -> usize { + let mut len: usize = 0; + $( + $crate::packets::macros::properties_wire_length!(self, len , PropertyType::$prop_variant); + )*; + len + } + } + }; +} + +macro_rules! properties_struct { + ( @ $name:ident { } -> ($($result:tt)*) ) => ( + #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] + pub struct $name { + $($result)* + } + ); + ( @ $name:ident { PayloadFormatIndicator, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.2 Payload Format Indicator + /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. + pub payload_format_indicator: Option, + )); + ); + ( @ $name:ident { MessageExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.3 Message Expiry Interval + /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. + pub message_expiry_interval: Option, + )); + ); + ( @ $name:ident { ContentType, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.9 Content Type + /// 3 (0x03) Identifier of the Content Type + pub content_type: Option>, + )); + ); + ( @ $name:ident { ResponseTopic, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.5 Response Topic + /// 8 (0x08) Byte, Identifier of the Response Topic. + pub response_topic: Option>, + )); + ); + ( @ $name:ident { CorrelationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.6 Correlation Data + /// 9 (0x09) Byte, Identifier of the Correlation Data. + pub correlation_data: Option>, + )); + ); + ( @ $name:ident { SubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.8 Subscription Identifier + /// 11 (0x0B), Identifier of the Subscription Identifier. + pub subscription_identifier: Vec, + )); + ); + ( @ $name:ident { SessionExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.2 Session Expiry Interval + /// 17 (0x11) Byte Identifier of the Session Expiry Interval + pub session_expiry_interval: Option, + )); + ); + ( @ $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.3 Receive Maximum + /// 33 (0x21) Byte, Identifier of the Receive Maximum + pub receive_maximum: Option, + )); + ); + ( @ $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.4 Topic Alias + /// 35 (0x23) Byte, Identifier of the Topic Alias. + pub topic_alias: Option, + )); + ); + ( @ $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.4 Maximum QoS + /// 36 (0x24) Byte, Identifier of the Maximum QoS. + pub maximum_qos: Option<$crate::packets::QoS>, + )); + ); + ( @ $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.5 Retain Available + /// 37 (0x25) Byte, Identifier of Retain Available. + pub retain_available: Option, + )); + ); + ( @ $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.6 Maximum Packet Size + /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. + pub maximum_packet_size: Option, + )); + ); + ( @ $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.7 Assigned Client Identifier + /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. + pub assigned_client_id: Option>, + )); + ); + ( @ $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.8 Topic Alias Maximum + /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. + pub topic_alias_maximum: Option, + )); + ); + ( @ $name:ident { ReasonString, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.9 Reason String + /// 31 (0x1F) Byte Identifier of the Reason String. + pub reason_string: Option>, + )); + ); + ( @ $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.10 User Property + /// 38 (0x26) Byte, Identifier of User Property. + pub user_properties: Vec<(Box, Box)>, + )); + ); + ( @ $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.11 Wildcard Subscription Available + /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. + pub wildcards_available: Option, + )); + ); + ( @ $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.12 Subscription Identifiers Available + /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. + pub subscription_ids_available: Option, + )); + ); + ( @ $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.13 Shared Subscription Available + /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. + pub shared_subscription_available: Option, + )); + ); + ( @ $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.14 Server Keep Alive + /// 19 (0x13) Byte, Identifier of the Server Keep Alive + pub server_keep_alive: Option, + )); + ); + ( @ $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.15 Response Information + /// 26 (0x1A) Byte, Identifier of the Response Information. + pub response_info: Option>, + )); + ); + ( @ $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.16 Server Reference + /// 28 (0x1C) Byte, Identifier of the Server Reference + pub server_reference: Option>, + )); + ); + + ( @ $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.17 Authentication Method + /// 21 (0x15) Byte, Identifier of the Authentication Method + pub authentication_method: Option>, + )); + ); + ( @ $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.18 Authentication Data + /// 22 (0x16) Byte, Identifier of the Authentication Data + // There is a small inconsistency here with authentication_data in the connect packet. + // This is Option while that type uses just Bytes. + pub authentication_data: Option>, + )); + ); + ( @ $name:ident { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + ); +} + +macro_rules! properties_read_matches { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => { + { + if $properties.session_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.session_expiry_interval = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::PayloadFormatIndicator) => { + { + if $properties.payload_format_indicator.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.payload_format_indicator = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MessageExpiryInterval) => { + { + if $properties.message_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.message_expiry_interval = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ContentType) => { + { + if $properties.content_type.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ContentType))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.content_type = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseTopic) => { + { + if $properties.response_topic.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_topic = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::CorrelationData) => { + { + if $properties.correlation_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::CorrelationData))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.correlation_data = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => { + { + let (prop_body, read_bytes) = VariableInteger::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifier.push(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => { + { + if $properties.receive_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.receive_maximum = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => { + { + if $properties.maximum_qos.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); + } + let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_qos = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => { + { + if $properties.retain_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.retain_available = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => { + { + if $properties.maximum_packet_size.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_packet_size = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => { + { + if $properties.assigned_client_id.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.assigned_client_id = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => { + { + if $properties.topic_alias_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.topic_alias_maximum = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReasonString) => { + { + if $properties.reason_string.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.reason_string = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => { + { + let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + + $properties.user_properties.push((prop_body_key, prop_body_value)) + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => { + { + if $properties.wildcards_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.wildcards_available = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => { + { + if $properties.subscription_ids_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_ids_available = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => { + { + if $properties.shared_subscription_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.shared_subscription_available = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => { + { + if $properties.server_keep_alive.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_keep_alive = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => { + { + if $properties.response_info.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_info = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => { + { + if $properties.server_reference.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_reference = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => { + { + if $properties.authentication_method.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_method = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => { + { + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_data = Some(prop_body); + } + }; +} + +macro_rules! properties_wire_length{ + ($self:ident, $len:ident, PropertyType::SessionExpiryInterval) => { + if $self.session_expiry_interval.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::PayloadFormatIndicator) => { + if $self.payload_format_indicator.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::MessageExpiryInterval) => { + if $self.message_expiry_interval.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::TopicAlias) => { + if $self.topic_alias.is_some() { + $len += 3; + } + }; + ($self:ident, $len:ident, PropertyType::ResponseTopic) => { + if let Some(response_topic) = &($self.response_topic) { + $len += 1 + response_topic.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::CorrelationData) => { + if let Some(correlation_data) = &($self.correlation_data) { + $len += 1 + correlation_data.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::SubscriptionIdentifier) => { + for sub_id in &($self.subscription_identifier) { + $len += 1 + $crate::packets::variable_integer_len(*sub_id); + } + }; + ($self:ident, $len:ident, PropertyType::ReceiveMaximum) => { + if $self.receive_maximum.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::MaximumQos) => { + if $self.maximum_qos.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::RetainAvailable) => { + if $self.retain_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::MaximumPacketSize) => { + if $self.maximum_packet_size.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::AssignedClientIdentifier) => { + if let Some(client_id) = $self.assigned_client_id.as_ref() { + $len += 1 + client_id.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::TopicAliasMaximum) => { + if $self.topic_alias_maximum.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::ReasonString) => { + if let Some(reason_string) = &($self.reason_string) { + $len += 1 + reason_string.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::UserProperty) => { + for (key, value) in &($self.user_properties) { + $len += 1; + $len += key.wire_len(); + $len += value.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::WildcardSubscriptionAvailable) => { + if $self.wildcards_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::SubscriptionIdentifierAvailable) => { + if $self.subscription_ids_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::SharedSubscriptionAvailable) => { + if $self.shared_subscription_available.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::ServerKeepAlive) => { + if $self.server_keep_alive.is_some() { + $len += 1 + 2; + } + }; + ($self:ident, $len:ident, PropertyType::ResponseInformation) => { + if let Some(response_info) = &($self.response_info) { + $len += 1 + response_info.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ServerReference) => { + if let Some(server_reference) = &($self.server_reference) { + $len += 1 + server_reference.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::AuthenticationMethod) => { + if let Some(authentication_method) = &($self.authentication_method) { + $len += 1 + authentication_method.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::AuthenticationData) => { + if $self.authentication_data.is_some() && $self.authentication_method.is_some() { + $len += 1 + $self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); + } + }; + ($self:ident, $len:ident, $unknown:ident) => ( + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + ); + +} + +pub(crate) use define_properties; +pub(crate) use properties_struct; +pub(crate) use properties_read_matches; +pub(crate) use properties_wire_length; \ No newline at end of file diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs new file mode 100644 index 0000000..6afe8be --- /dev/null +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -0,0 +1,527 @@ +macro_rules! reason_code { + ($name:ident, $($code:ident),*) => { + #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub enum $name { + #[default] + $($code),* + } + + impl $crate::packets::mqtt_traits::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + std::marker::Unpin{ + async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { + let input = stream.read_u8().await?; + let res = $crate::packets::macros::reason_code_match!(@ $name, input, { + $($code,)* + } -> ())?; + Ok((res, 1)) + } + } + + impl $crate::packets::mqtt_traits::MqttRead for $name { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + let res = buf.get_u8(); + $crate::packets::macros::reason_code_match!(@ $name, res, { + $($code,)* + } -> ()) + } + } + + impl $crate::packets::mqtt_traits::MqttWrite for $name { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + let val = $crate::packets::macros::reason_code_match_write!(@ $name, buf, self, { + $($code,)* + } -> ()); + buf.put_u8(val); + Ok(()) + } + } + + }; +} + +macro_rules! reason_code_match { + ( @ $name:ident, $input:ident, { } -> ($($result:tt)*) ) => ( + match $input { + $($result)* + t => Err(DeserializeError::UnknownProperty(t)), + } + ); + ( @ $name:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::Success), + )) + ); + ( @ $name:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::NormalDisconnection), + )) + ); + ( @ $name:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x04 => Ok($name::DisconnectWithWillMessage), + )) + ); + ( @ $name:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x18 => Ok($name::ContinueAuthentication), + )) + ); + ( @ $name:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x19 => Ok($name::ReAuthenticate), + )) + ); + ( @ $name:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x80 => Ok($name::UnspecifiedError), + )) + ); + ( @ $name:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x81 => Ok($name::MalformedPacket), + )) + ); + ( @ $name:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x82 => Ok($name::ProtocolError), + )) + ); + ( @ $name:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x83 => Ok($name::ImplementationSpecificError), + )) + ); + ( @ $name:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x84 => Ok($name::UnsupportedProtocolVersion), + )) + ); + ( @ $name:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x85 => Ok($name::ClientIdentifierNotValid), + )) + ); + ( @ $name:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x86 => Ok($name::BadUsernameOrPassword), + )) + ); + ( @ $name:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x87 => Ok($name::NotAuthorized), + )) + ); + ( @ $name:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x88 => Ok($name::ServerUnavailable), + )) + ); + ( @ $name:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x89 => Ok($name::ServerBusy), + )) + ); + ( @ $name:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8A => Ok($name::Banned), + )) + ); + ( @ $name:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), + )) + ); + ( @ $name:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8C => Ok($name::BadAuthenticationMethod), + )) + ); + ( @ $name:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), + )) + ); + ( @ $name:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), + )) + ); + ( @ $name:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), + )) + ); + ( @ $name:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x90 => Ok($name::TopicNameInvalid), + )) + ); + ( @ $name:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), + )) + ); + ( @ $name:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), + )) + ); + ( @ $name:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x95 => Ok($name::PacketTooLarge), + )) + ); + ( @ $name:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x96 => Ok($name::MessageRateTooHigh), + )) + ); + ( @ $name:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x97 => Ok($name::QuotaExceeded), + )) + ); + ( @ $name:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x98 => Ok($name::AdministrativeAction), + )) + ); + ( @ $name:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x99 => Ok($name::PayloadFormatInvalid), + )) + ); + ( @ $name:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9A => Ok($name::RetainNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9B => Ok($name::QosNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9C => Ok($name::UseAnotherServer), + )) + ); + ( @ $name:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9D => Ok($name::ServerMoved), + )) + ); + ( @ $name:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9E => Ok($name::SharedSubscriptionsNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x9F => Ok($name::ConnectionRateExceeded), + )) + ); + ( @ $name:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA0 => Ok($name::MaximumConnectTime), + )) + ); + ( @ $name:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA1 => Ok($name::SubscriptionIdentifiersNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0xA2 => Ok($name::WildcardSubscriptionsNotSupported), + )) + ); + ( @ $name:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown reason_code: ", stringify!($unknown))) + ); +} + +macro_rules! reason_code_match_write{ + ( @ $name:ident, $buf:ident, $input:ident, { } -> ($($result:tt)*) ) => ( + match $input { + $($result)* + } + ); + ( @ $name:ident, $buf:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::Success => 0x00, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::NormalDisconnection => 0x00, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::DisconnectWithWillMessage => 0x04, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ContinueAuthentication => 0x18, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ReAuthenticate => 0x19, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::UnspecifiedError => 0x80, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::MalformedPacket => 0x81, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ProtocolError => 0x82, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ImplementationSpecificError => 0x83, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::UnsupportedProtocolVersion => 0x84, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ClientIdentifierNotValid => 0x85, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::BadUsernameOrPassword => 0x86, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::NotAuthorized => 0x87, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ServerUnavailable => 0x88, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ServerBusy => 0x89, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::Banned => 0x8A, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ServerShuttingDown => 0x8B , + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::BadAuthenticationMethod => 0x8C, + )) + ); + + ( @ $name:ident, $buf:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::KeepAliveTimeout => 0x8D, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::SessionTakenOver => 0x8E, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::TopicFilterInvalid => 0x8F, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::TopicNameInvalid => 0x90, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ReceiveMaximumExceeded => 0x93, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::TopicAliasInvalid => 0x94, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::PacketTooLarge => 0x95, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::MessageRateTooHigh => 0x96, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::QuotaExceeded => 0x97, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::AdministrativeAction => 0x98, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::PayloadFormatInvalid => 0x99, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::RetainNotSupported => 0x9A, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::QosNotSupported => 0x9B, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::UseAnotherServer => 0x9C, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ServerMoved => 0x9D, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::SharedSubscriptionsNotSupported => 0x9E, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::ConnectionRateExceeded => 0x9F, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::MaximumConnectTime => 0xA0, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::SubscriptionIdentifiersNotSupported => 0xA1, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::WildcardSubscriptionsNotSupported => 0xA2, + )) + ); + + ( @ $name:ident, $buf:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + compile_error!(concat!("Unknown reason_code: ", stringify!($unknown))) + ); +} + +pub(crate) use reason_code; +pub(crate) use reason_code_match; +pub(crate) use reason_code_match_write; \ No newline at end of file diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 5f7c6b2..dfcce1f 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -3,6 +3,7 @@ pub mod mqtt_traits; pub mod reason_codes; mod macros; + mod auth; mod connack; mod connect; @@ -38,7 +39,7 @@ use core::slice::Iter; use std::fmt::Display; use self::error::{DeserializeError, ReadBytes, SerializeError}; -use self::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; +use self::mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; /// Protocol version #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] @@ -68,10 +69,10 @@ impl MqttRead for ProtocolVersion { } } -impl MqttAsyncRead for ProtocolVersion where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - match buf.read_u8().await { - Ok(5) => Ok(ProtocolVersion::V5), +impl MqttAsyncRead for ProtocolVersion where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + match stream.read_u8().await { + Ok(5) => Ok((ProtocolVersion::V5, 1)), Ok(4) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), Ok(3) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), Ok(_) => Err(ReadError::DeserializeError(DeserializeError::UnknownProtocolVersion)), @@ -123,11 +124,11 @@ impl MqttRead for QoS { } impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await { - Ok(0) => Ok(QoS::AtMostOnce), - Ok(1) => Ok(QoS::AtLeastOnce), - Ok(2) => Ok(QoS::ExactlyOnce), + Ok(0) => Ok((QoS::AtMostOnce, 1)), + Ok(1) => Ok((QoS::AtLeastOnce, 1)), + Ok(2) => Ok((QoS::ExactlyOnce, 1)), Ok(q) => Err(ReadError::DeserializeError(DeserializeError::UnknownQoS(q))), Err(e) => Err(ReadError::IoError(e)), } @@ -159,11 +160,11 @@ impl MqttRead for Box { } } -impl MqttAsyncRead for Box where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - let content = Bytes::async_read(buf).await?; - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s.into()), +impl MqttAsyncRead for Box where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Vec::async_read(stream).await?; + match String::from_utf8(content) { + Ok(s) => Ok((s.into(), read_bytes)), Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), } } @@ -212,10 +213,10 @@ impl MqttRead for String { } impl MqttAsyncRead for String where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - let content = Bytes::async_read(buf).await?; + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Bytes::async_read(buf).await?; match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s), + Ok(s) => Ok((s, read_bytes)), Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), } } @@ -255,16 +256,17 @@ impl MqttRead for Bytes { } } -impl MqttAsyncRead for Bytes where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - let size = buf.read_u16().await?; - let mut data = BytesMut::with_capacity(size as usize); - buf.read_exact(&mut data).await?; - Ok(data.into()) +impl MqttAsyncRead for Bytes where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = Vec::with_capacity(size); + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data.into(), 2 + size)) } } - impl MqttWrite for Bytes { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { @@ -282,6 +284,48 @@ impl WireLength for Bytes { } } +impl MqttRead for Vec { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let len = buf.get_u16() as usize; + + if len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + Ok(buf.split_to(len).into()) + } +} +impl MqttWrite for Vec { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self); + + Ok(()) + } +} +impl WireLength for Vec { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + + + +impl MqttAsyncRead for Vec where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = vec![0u8; size]; + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data, 2 + size)) + } +} + + impl MqttRead for bool { fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { @@ -297,10 +341,10 @@ impl MqttRead for bool { } impl MqttAsyncRead for bool where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await? { - 0 => Ok(false), - 1 => Ok(true), + 0 => Ok((false, 1)), + 1 => Ok((true, 1)), _ => Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)), } } @@ -330,8 +374,8 @@ impl MqttRead for u8 { } impl MqttAsyncRead for u8 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - Ok(buf.read_u8().await?) + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u8().await?, 1)) } } @@ -346,8 +390,8 @@ impl MqttRead for u16 { } impl MqttAsyncRead for u16 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - Ok(buf.read_u16().await?) + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u16().await?, 2)) } } @@ -369,8 +413,8 @@ impl MqttRead for u32 { } } impl MqttAsyncRead for u32 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { - Ok(buf.read_u32().await?) + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u32().await?, 4)) } } @@ -420,14 +464,31 @@ pub fn read_variable_integer(buf: &mut Bytes) -> Result<(usize, usize), Deserial Err(DeserializeError::MalformedPacket) } -pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result<(), SerializeError> { +pub async fn read_async_variable_integer(stream: &mut S) -> Result<(usize, usize), ReadError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) +} + +pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result { if integer > 268_435_455 { return Err(SerializeError::VariableIntegerOverflow(integer)); } let mut write = integer; - for _ in 0..4 { + for i in 0..4 { let mut byte = (write % 128) as u8; write /= 128; if write > 0 { @@ -435,7 +496,7 @@ pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result<(), } buf.put_u8(byte); if write == 0 { - return Ok(()); + return Ok(i+1); } } Err(SerializeError::VariableIntegerOverflow(integer)) @@ -572,9 +633,9 @@ impl MqttRead for PropertyType { } impl MqttAsyncRead for PropertyType where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await { - Ok(t) => Ok(t.try_into()?), + Ok(t) => Ok((t.try_into()?, 1)), Err(e) => Err(ReadError::IoError(e)), } } diff --git a/mqrstt/src/packets/mqtt_traits.rs b/mqrstt/src/packets/mqtt_traits.rs index 35fb6fd..72ebb9f 100644 --- a/mqrstt/src/packets/mqtt_traits.rs +++ b/mqrstt/src/packets/mqtt_traits.rs @@ -1,12 +1,18 @@ +use std::{future::Future, process::Output}; + use bytes::{Bytes, BytesMut}; use super::error::{DeserializeError, ReadError, SerializeError}; -pub trait VariableHeaderRead: Sized { +pub trait PacketRead: Sized { fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; } -pub trait VariableHeaderWrite: Sized { +pub trait PacketAsyncRead: Sized where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; +} + +pub trait PacketWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; } @@ -17,8 +23,12 @@ pub trait WireLength { pub trait MqttRead: Sized { fn read(buf: &mut Bytes) -> Result; } -pub trait MqttAsyncRead: Sized where T: tokio::io::AsyncReadExt { - async fn async_read(buf: &mut T) -> Result; +pub trait MqttAsyncRead: Sized +// where S: tokio::io::AsyncReadExt + Unpin +{ + /// Reads `Self` from the provided stream. + /// Returns the deserialized instance and the number of bytes read from the stream. + fn async_read(stream: &mut S) -> impl Future>; } diff --git a/mqrstt/src/packets/puback.rs b/mqrstt/src/packets/puback.rs index bfba6b8..69b5c63 100644 --- a/mqrstt/src/packets/puback.rs +++ b/mqrstt/src/packets/puback.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::PubAckReasonCode, variable_integer_len, write_variable_integer, PacketType, PropertyType, @@ -15,7 +15,7 @@ pub struct PubAck { pub properties: PubAckProperties, } -impl VariableHeaderRead for PubAck { +impl PacketRead for PubAck { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -42,7 +42,7 @@ impl VariableHeaderRead for PubAck { } } -impl VariableHeaderWrite for PubAck { +impl PacketWrite for PubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -100,7 +100,7 @@ impl MqttRead for PubAckProperties { let mut properties = PubAckProperties::default(); loop { - match PropertyType::from_u8(u8::read(buf)?)? { + match PropertyType::try_from(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); @@ -155,7 +155,7 @@ impl WireLength for PubAckProperties { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, puback::{PubAck, PubAckProperties}, reason_codes::PubAckReasonCode, write_variable_integer, PropertyType, diff --git a/mqrstt/src/packets/pubcomp.rs b/mqrstt/src/packets/pubcomp.rs index 69ba679..119699c 100644 --- a/mqrstt/src/packets/pubcomp.rs +++ b/mqrstt/src/packets/pubcomp.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::PubCompReasonCode, write_variable_integer, PacketType, PropertyType, @@ -25,7 +25,7 @@ impl PubComp { } } -impl VariableHeaderRead for PubComp { +impl PacketRead for PubComp { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -52,7 +52,7 @@ impl VariableHeaderRead for PubComp { } } -impl VariableHeaderWrite for PubComp { +impl PacketWrite for PubComp { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -106,7 +106,7 @@ impl MqttRead for PubCompProperties { let mut properties = PubCompProperties::default(); loop { - match PropertyType::from_u8(u8::read(buf)?)? { + match PropertyType::try_from(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); @@ -161,7 +161,7 @@ impl WireLength for PubCompProperties { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubcomp::{PubComp, PubCompProperties}, reason_codes::PubCompReasonCode, write_variable_integer, PropertyType, diff --git a/mqrstt/src/packets/publish.rs b/mqrstt/src/packets/publish.rs index d744a9b..15fa609 100644 --- a/mqrstt/src/packets/publish.rs +++ b/mqrstt/src/packets/publish.rs @@ -3,7 +3,7 @@ use bytes::{BufMut, Bytes}; use crate::error::PacketValidationError; use crate::util::constants::MAXIMUM_TOPIC_SIZE; -use super::mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}; +use super::mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; use super::{ error::{DeserializeError, SerializeError}, read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, @@ -51,7 +51,7 @@ impl Publish { } } -impl VariableHeaderRead for Publish { +impl PacketRead for Publish { fn read(flags: u8, _: usize, mut buf: bytes::Bytes) -> Result { let dup = flags & 0b1000 != 0; let qos = QoS::from_u8((flags & 0b110) >> 1)?; @@ -77,7 +77,7 @@ impl VariableHeaderRead for Publish { } } -impl VariableHeaderWrite for Publish { +impl PacketWrite for Publish { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { self.topic.write(buf)?; @@ -122,6 +122,17 @@ impl PacketValidation for Publish { } } +// super::macros::define_properties!(PublishProperties, +// PayloadFormatIndicator, +// MessageExpiryInterval, +// ContentType, +// ResponseTopic, +// CorrelationData, +// SubscriptionIdentifier, +// TopicAlias, +// UserProperty +// ); + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct PublishProperties { /// 3.3.2.3.2 Payload Format Indicator @@ -132,9 +143,9 @@ pub struct PublishProperties { /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. pub message_expiry_interval: Option, - /// 3.3.2.3.4 Topic Alias - /// 35 (0x23) Byte, Identifier of the Topic Alias. - pub topic_alias: Option, + /// 3.3.2.3.9 Content Type + /// 3 (0x03) Identifier of the Content Type + pub content_type: Option>, /// 3.3.2.3.5 Response Topic /// 8 (0x08) Byte, Identifier of the Response Topic. @@ -148,13 +159,14 @@ pub struct PublishProperties { /// 11 (0x0B), Identifier of the Subscription Identifier. pub subscription_identifier: Vec, + /// 3.3.2.3.4 Topic Alias + /// 35 (0x23) Byte, Identifier of the Topic Alias. + pub topic_alias: Option, + /// 3.3.2.3.7 User Property /// 38 (0x26) Byte, Identifier of the User Property. pub user_properties: Vec<(Box, Box)>, - /// 3.3.2.3.9 Content Type - /// 3 (0x03) Identifier of the Content Type - pub content_type: Option>, } impl MqttRead for PublishProperties { @@ -304,7 +316,7 @@ mod tests { use bytes::{BufMut, BytesMut}; use crate::packets::{ - mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{PacketRead, PacketWrite, WireLength}, write_variable_integer, }; @@ -313,6 +325,7 @@ mod tests { #[test] fn test_read_write_properties() { let first_byte = 0b0011_0100; + dbg!("1"); let mut properties = [1, 0, 2].to_vec(); properties.extend(4_294_967_295u32.to_be_bytes()); @@ -338,6 +351,7 @@ mod tests { ] .to_vec(), ); + dbg!("a"); let rem_len = buf_one.len(); @@ -345,12 +359,15 @@ mod tests { let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); + dbg!("b"); + let mut result_buf = BytesMut::new(); + dbg!(p.wire_len()); p.write(&mut result_buf).unwrap(); - dbg!(p.clone()); + // dbg!(p.clone()); - assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + // assert_eq!(buf_one.to_vec(), result_buf.to_vec()) } #[test] diff --git a/mqrstt/src/packets/pubrec.rs b/mqrstt/src/packets/pubrec.rs index 2351b4e..065e421 100644 --- a/mqrstt/src/packets/pubrec.rs +++ b/mqrstt/src/packets/pubrec.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::PubRecReasonCode, write_variable_integer, PacketType, PropertyType, @@ -24,7 +24,7 @@ impl PubRec { } } -impl VariableHeaderRead for PubRec { +impl PacketRead for PubRec { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -51,7 +51,7 @@ impl VariableHeaderRead for PubRec { } } -impl VariableHeaderWrite for PubRec { +impl PacketWrite for PubRec { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -105,7 +105,7 @@ impl MqttRead for PubRecProperties { let mut properties = PubRecProperties::default(); loop { - match PropertyType::from_u8(u8::read(buf)?)? { + match PropertyType::try_from(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); @@ -160,7 +160,7 @@ impl WireLength for PubRecProperties { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubrec::{PubRec, PubRecProperties}, reason_codes::PubRecReasonCode, write_variable_integer, PropertyType, diff --git a/mqrstt/src/packets/pubrel.rs b/mqrstt/src/packets/pubrel.rs index aedbc9b..caadaf2 100644 --- a/mqrstt/src/packets/pubrel.rs +++ b/mqrstt/src/packets/pubrel.rs @@ -1,11 +1,7 @@ use bytes::BufMut; use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::PubRelReasonCode, - write_variable_integer, PacketType, PropertyType, + error::{DeserializeError, ReadError}, mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, read_async_variable_integer, read_variable_integer, reason_codes::PubRelReasonCode, write_variable_integer, PacketType, PropertyType }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -25,7 +21,7 @@ impl PubRel { } } -impl VariableHeaderRead for PubRel { +impl PacketRead for PubRel { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. if remaining_length == 2 { @@ -50,7 +46,41 @@ impl VariableHeaderRead for PubRel { } } -impl VariableHeaderWrite for PubRel { +impl PacketAsyncRead for PubRel where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut total_read_bytes = 0; + let (packet_identifier, read_bytes) = u16::async_read(stream).await?; + total_read_bytes += read_bytes; + let res = if remaining_length == 2 { + Self { + packet_identifier, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + } + } else { + let (reason_code, read_bytes) = PubRelReasonCode::async_read(stream).await?; + total_read_bytes += read_bytes; + if remaining_length == 3 { + Self { + packet_identifier, + reason_code, + properties: PubRelProperties::default(), + } + } else { + let (properties, read_bytes) = PubRelProperties::async_read(stream).await?; + total_read_bytes += read_bytes; + Self { + packet_identifier, + reason_code, + properties, + } + } + }; + Ok((res, total_read_bytes)) + } +} + +impl PacketWrite for PubRel { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -78,11 +108,13 @@ impl WireLength for PubRel { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubRelProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} +// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] +// pub struct PubRelProperties { +// pub reason_string: Option>, +// pub user_properties: Vec<(Box, Box)>, +// } + +super::macros::define_properties!(PubRelProperties, ReasonString, UserProperty); impl PubRelProperties { pub fn is_empty(&self) -> bool { @@ -104,7 +136,7 @@ impl MqttRead for PubRelProperties { let mut properties = PubRelProperties::default(); loop { - match PropertyType::from_u8(u8::read(buf)?)? { + match PropertyType::try_from(u8::read(buf)?)? { PropertyType::ReasonString => { if properties.reason_string.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); @@ -122,6 +154,47 @@ impl MqttRead for PubRelProperties { } } +// impl MqttAsyncRead for PubRelProperties where S: tokio::io::AsyncReadExt + Unpin { +// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { +// let (len, length_variable_integer) = read_async_variable_integer(stream).await?; +// if len == 0 { +// return Ok((Self::default(), length_variable_integer)); +// } + +// let mut properties = PubRelProperties::default(); + +// let mut read_property_bytes = 0; +// loop { +// let (prop, read_bytes) = PropertyType::async_read(stream).await?; +// read_property_bytes += read_bytes; +// match prop { +// PropertyType::ReasonString => { +// if properties.reason_string.is_some() { +// return Err(super::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); +// } +// let (prop_body, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; +// properties.reason_string = Some(prop_body); +// } +// PropertyType::UserProperty => { +// let (prop_body_key, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; +// let (prop_body_value, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; + +// properties.user_properties.push((prop_body_key, prop_body_value)) +// }, +// e => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), +// } +// if read_property_bytes == len { +// break; +// } +// } + +// Ok((properties, length_variable_integer + read_property_bytes)) +// } +// } + impl MqttWrite for PubRelProperties { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let len = self.wire_len(); @@ -142,29 +215,30 @@ impl MqttWrite for PubRelProperties { } } -impl WireLength for PubRelProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } +// impl WireLength for PubRelProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } - len - } -} +// len +// } +// } #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, pubrel::{PubRel, PubRelProperties}, reason_codes::PubRelReasonCode, write_variable_integer, PropertyType, }; - use bytes::{BufMut, Bytes, BytesMut}; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + use tokio::{io::ReadBuf, stream}; #[test] fn test_wire_len() { @@ -189,6 +263,29 @@ mod tests { assert_eq!(3, buf.len()); } + + #[test] + fn test_wire_len2() { + let mut buf = BytesMut::new(); + + let prop = PubRelProperties{ + reason_string: Some("reason string, test 1-2-3.".into()), // 26 + 1 + 2 + user_properties: vec![ + ("This is the key".into(), "This is the value".into()), // 32 + 1 + 2 + 2 + ("Another thingy".into(), "The thingy".into()), // 24 + 1 + 2 + 2 + ], + }; + + let len = prop.wire_len(); + // determine length of variable integer + let len_of_wire_len = write_variable_integer(&mut buf, len).unwrap(); + // clear buffer before writing actual properties + buf.clear(); + prop.write(&mut buf).unwrap(); + + assert_eq!(len + len_of_wire_len, buf.len()); + } + #[test] fn test_read_short() { let mut expected_pubrel = PubRel { @@ -217,6 +314,39 @@ mod tests { assert_eq!(expected_pubrel, pubrel); } + #[tokio::test] + async fn test_async_read_short() { + let mut expected_pubrel = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + let mut buf = BytesMut::new(); + + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(2, buf.len()); + let mut stream: &[u8] = &*buf; + + let (pubrel, read_bytes) = PubRel::async_read(0, 2, &mut stream).await.unwrap(); + + assert_eq!(expected_pubrel, pubrel); + assert_eq!(read_bytes, 2); + + let mut buf = BytesMut::new(); + expected_pubrel.reason_code = PubRelReasonCode::PacketIdentifierNotFound; + expected_pubrel.write(&mut buf).unwrap(); + + assert_eq!(3, buf.len()); + let mut stream: &[u8] = &*buf; + + let (pubrel, read_bytes) = PubRel::async_read(0, 3, &mut stream).await.unwrap(); + assert_eq!(read_bytes, 3); + assert_eq!(expected_pubrel, pubrel); + } + + #[test] fn test_read_simple_pub_rel() { let stream = &[ @@ -235,6 +365,27 @@ mod tests { assert_eq!(expected, p_ack); } + #[tokio::test] + async fn test_async_read_simple_pub_rel() { + let stream = &[ + 0x00, 0x0C, // Packet identifier = 12 + 0x00, // Reason code success + 0x00, // no properties + ]; + + let mut stream = stream.as_ref(); + + let (p_ack, read_bytes) = PubRel::async_read(0, 4, &mut stream).await.unwrap(); + + let expected = PubRel { + packet_identifier: 12, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties::default(), + }; + + assert_eq!(expected, p_ack); + assert_eq!(read_bytes, 4); + } #[test] fn test_read_write_pubrel_with_properties() { @@ -267,6 +418,39 @@ mod tests { assert_eq!(buf.to_vec(), result.to_vec()); } + #[tokio::test] + async fn test_async_read_write_pubrel_with_properties() { + let mut buf = BytesMut::new(); + + buf.put_u16(65_535u16); + buf.put_u8(0x92); + + let mut properties = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties).unwrap(); + "reason string, test 1-2-3.".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "This is the key".write(&mut properties).unwrap(); + "This is the value".write(&mut properties).unwrap(); + PropertyType::UserProperty.write(&mut properties).unwrap(); + "Another thingy".write(&mut properties).unwrap(); + "The thingy".write(&mut properties).unwrap(); + + write_variable_integer(&mut buf, properties.len()).unwrap(); + + buf.extend(properties); + + + let mut stream = &*buf; + // flags can be 0 because not used. + // remaining_length must be at least 4 + let (p_ack, read_bytes) = PubRel::async_read(0, buf.len(), &mut stream).await.unwrap(); + + let mut result = BytesMut::new(); + p_ack.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } + #[test] fn test_properties() { let mut properties_data = BytesMut::new(); @@ -290,6 +474,30 @@ mod tests { assert_eq!(buf.to_vec(), result.to_vec()); } + #[tokio::test] + async fn test_async_read_properties() { + let mut properties_data = BytesMut::new(); + PropertyType::ReasonString.write(&mut properties_data).unwrap(); + "reason string, test 1-2-3.".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "This is the key".write(&mut properties_data).unwrap(); + "This is the value".write(&mut properties_data).unwrap(); + PropertyType::UserProperty.write(&mut properties_data).unwrap(); + "Another thingy".write(&mut properties_data).unwrap(); + "The thingy".write(&mut properties_data).unwrap(); + + let mut buf = BytesMut::new(); + write_variable_integer(&mut buf, properties_data.len()).unwrap(); + buf.extend(properties_data); + + let (properties, read_bytes) = PubRelProperties::async_read(&mut &*buf).await.unwrap(); + let mut result = BytesMut::new(); + properties.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + assert_eq!(buf.len(), read_bytes); + } + #[test] fn no_reason_code_or_props() { let mut buf = BytesMut::new(); diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs index 604189e..90bc37a 100644 --- a/mqrstt/src/packets/reason_codes.rs +++ b/mqrstt/src/packets/reason_codes.rs @@ -3,16 +3,15 @@ use std::default; use bytes::{Buf, BufMut}; use super::error::DeserializeError; -use super::mqtt_traits::{MqttRead, MqttWrite}; +use super::mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite}; + -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum ConnAckReasonCode { - #[default] - Success, +super::macros::reason_code!(ConnAckReasonCode, + Success, UnspecifiedError, MalformedPacket, - ProtocolError, + ProtocolError, ImplementationSpecificError, UnsupportedProtocolVersion, ClientIdentifierNotValid, @@ -30,115 +29,173 @@ pub enum ConnAckReasonCode { QosNotSupported, UseAnotherServer, ServerMoved, - ConnectionRateExceeded, -} - -impl MqttRead for ConnAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(ConnAckReasonCode::Success), - 0x80 => Ok(ConnAckReasonCode::UnspecifiedError), - 0x81 => Ok(ConnAckReasonCode::MalformedPacket), - 0x82 => Ok(ConnAckReasonCode::ProtocolError), - 0x83 => Ok(ConnAckReasonCode::ImplementationSpecificError), - 0x84 => Ok(ConnAckReasonCode::UnsupportedProtocolVersion), - 0x85 => Ok(ConnAckReasonCode::ClientIdentifierNotValid), - 0x86 => Ok(ConnAckReasonCode::BadUsernameOrPassword), - 0x87 => Ok(ConnAckReasonCode::NotAuthorized), - 0x88 => Ok(ConnAckReasonCode::ServerUnavailable), - 0x89 => Ok(ConnAckReasonCode::ServerBusy), - 0x8A => Ok(ConnAckReasonCode::Banned), - 0x8C => Ok(ConnAckReasonCode::BadAuthenticationMethod), - 0x90 => Ok(ConnAckReasonCode::TopicNameInvalid), - 0x95 => Ok(ConnAckReasonCode::PacketTooLarge), - 0x97 => Ok(ConnAckReasonCode::QuotaExceeded), - 0x99 => Ok(ConnAckReasonCode::PayloadFormatInvalid), - 0x9A => Ok(ConnAckReasonCode::RetainNotSupported), - 0x9B => Ok(ConnAckReasonCode::QosNotSupported), - 0x9C => Ok(ConnAckReasonCode::UseAnotherServer), - 0x9D => Ok(ConnAckReasonCode::ServerMoved), - 0x9F => Ok(ConnAckReasonCode::ConnectionRateExceeded), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for ConnAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - ConnAckReasonCode::Success => 0x00, - ConnAckReasonCode::UnspecifiedError => 0x80, - ConnAckReasonCode::MalformedPacket => 0x81, - ConnAckReasonCode::ProtocolError => 0x82, - ConnAckReasonCode::ImplementationSpecificError => 0x83, - ConnAckReasonCode::UnsupportedProtocolVersion => 0x84, - ConnAckReasonCode::ClientIdentifierNotValid => 0x85, - ConnAckReasonCode::BadUsernameOrPassword => 0x86, - ConnAckReasonCode::NotAuthorized => 0x87, - ConnAckReasonCode::ServerUnavailable => 0x88, - ConnAckReasonCode::ServerBusy => 0x89, - ConnAckReasonCode::Banned => 0x8A, - ConnAckReasonCode::BadAuthenticationMethod => 0x8C, - ConnAckReasonCode::TopicNameInvalid => 0x90, - ConnAckReasonCode::PacketTooLarge => 0x95, - ConnAckReasonCode::QuotaExceeded => 0x97, - ConnAckReasonCode::PayloadFormatInvalid => 0x99, - ConnAckReasonCode::RetainNotSupported => 0x9A, - ConnAckReasonCode::QosNotSupported => 0x9B, - ConnAckReasonCode::UseAnotherServer => 0x9C, - ConnAckReasonCode::ServerMoved => 0x9D, - ConnAckReasonCode::ConnectionRateExceeded => 0x9F, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum AuthReasonCode { + ConnectionRateExceeded +); + + +// #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +// pub enum ConnAckReasonCode { +// #[default] +// Success, + +// UnspecifiedError, +// MalformedPacket, +// ProtocolError, +// ImplementationSpecificError, +// UnsupportedProtocolVersion, +// ClientIdentifierNotValid, +// BadUsernameOrPassword, +// NotAuthorized, +// ServerUnavailable, +// ServerBusy, +// Banned, +// BadAuthenticationMethod, +// TopicNameInvalid, +// PacketTooLarge, +// QuotaExceeded, +// PayloadFormatInvalid, +// RetainNotSupported, +// QosNotSupported, +// UseAnotherServer, +// ServerMoved, +// ConnectionRateExceeded, +// } + +// impl MqttRead for ConnAckReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } +// let res = buf.get_u8(); + +// crate::packets::macros::reason_code_match!(@ ConnAckReasonCode, res, { +// Success, +// UnspecifiedError, +// MalformedPacket, +// ProtocolError, +// ImplementationSpecificError, +// UnsupportedProtocolVersion, +// ClientIdentifierNotValid, +// BadUsernameOrPassword, +// NotAuthorized, +// ServerUnavailable, +// ServerBusy, +// Banned, +// BadAuthenticationMethod, +// TopicNameInvalid, +// PacketTooLarge, +// QuotaExceeded, +// PayloadFormatInvalid, +// RetainNotSupported, +// QosNotSupported, +// UseAnotherServer, +// ServerMoved, +// ConnectionRateExceeded, +// } -> ()) +// // match buf.get_u8() { +// // 0x00 => Ok(ConnAckReasonCode::Success), +// // 0x80 => Ok(ConnAckReasonCode::UnspecifiedError), +// // 0x81 => Ok(ConnAckReasonCode::MalformedPacket), +// // 0x82 => Ok(ConnAckReasonCode::ProtocolError), +// // 0x83 => Ok(ConnAckReasonCode::ImplementationSpecificError), +// // 0x84 => Ok(ConnAckReasonCode::UnsupportedProtocolVersion), +// // 0x85 => Ok(ConnAckReasonCode::ClientIdentifierNotValid), +// // 0x86 => Ok(ConnAckReasonCode::BadUsernameOrPassword), +// // 0x87 => Ok(ConnAckReasonCode::NotAuthorized), +// // 0x88 => Ok(ConnAckReasonCode::ServerUnavailable), +// // 0x89 => Ok(ConnAckReasonCode::ServerBusy), +// // 0x8A => Ok(ConnAckReasonCode::Banned), +// // 0x8C => Ok(ConnAckReasonCode::BadAuthenticationMethod), +// // 0x90 => Ok(ConnAckReasonCode::TopicNameInvalid), +// // 0x95 => Ok(ConnAckReasonCode::PacketTooLarge), +// // 0x97 => Ok(ConnAckReasonCode::QuotaExceeded), +// // 0x99 => Ok(ConnAckReasonCode::PayloadFormatInvalid), +// // 0x9A => Ok(ConnAckReasonCode::RetainNotSupported), +// // 0x9B => Ok(ConnAckReasonCode::QosNotSupported), +// // 0x9C => Ok(ConnAckReasonCode::UseAnotherServer), +// // 0x9D => Ok(ConnAckReasonCode::ServerMoved), +// // 0x9F => Ok(ConnAckReasonCode::ConnectionRateExceeded), +// // t => Err(DeserializeError::UnknownProperty(t)), +// // } +// } +// } + +// impl MqttWrite for ConnAckReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// ConnAckReasonCode::Success => 0x00, +// ConnAckReasonCode::UnspecifiedError => 0x80, +// ConnAckReasonCode::MalformedPacket => 0x81, +// ConnAckReasonCode::ProtocolError => 0x82, +// ConnAckReasonCode::ImplementationSpecificError => 0x83, +// ConnAckReasonCode::UnsupportedProtocolVersion => 0x84, +// ConnAckReasonCode::ClientIdentifierNotValid => 0x85, +// ConnAckReasonCode::BadUsernameOrPassword => 0x86, +// ConnAckReasonCode::NotAuthorized => 0x87, +// ConnAckReasonCode::ServerUnavailable => 0x88, +// ConnAckReasonCode::ServerBusy => 0x89, +// ConnAckReasonCode::Banned => 0x8A, +// ConnAckReasonCode::BadAuthenticationMethod => 0x8C, +// ConnAckReasonCode::TopicNameInvalid => 0x90, +// ConnAckReasonCode::PacketTooLarge => 0x95, +// ConnAckReasonCode::QuotaExceeded => 0x97, +// ConnAckReasonCode::PayloadFormatInvalid => 0x99, +// ConnAckReasonCode::RetainNotSupported => 0x9A, +// ConnAckReasonCode::QosNotSupported => 0x9B, +// ConnAckReasonCode::UseAnotherServer => 0x9C, +// ConnAckReasonCode::ServerMoved => 0x9D, +// ConnAckReasonCode::ConnectionRateExceeded => 0x9F, +// }; + +// buf.put_u8(val); + +// Ok(()) +// } +// } + +super::macros::reason_code!(AuthReasonCode, Success, ContinueAuthentication, - ReAuthenticate, -} - -impl MqttRead for AuthReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(AuthReasonCode::Success), - 0x18 => Ok(AuthReasonCode::ContinueAuthentication), - 0x19 => Ok(AuthReasonCode::ReAuthenticate), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for AuthReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - AuthReasonCode::Success => 0x00, - AuthReasonCode::ContinueAuthentication => 0x18, - AuthReasonCode::ReAuthenticate => 0x19, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum DisconnectReasonCode { - #[default] + ReAuthenticate +); + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum AuthReasonCode { +// Success, +// ContinueAuthentication, +// ReAuthenticate, +// } + +// impl MqttRead for AuthReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(AuthReasonCode::Success), +// 0x18 => Ok(AuthReasonCode::ContinueAuthentication), +// 0x19 => Ok(AuthReasonCode::ReAuthenticate), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for AuthReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// AuthReasonCode::Success => 0x00, +// AuthReasonCode::ContinueAuthentication => 0x18, +// AuthReasonCode::ReAuthenticate => 0x19, +// }; + +// buf.put_u8(val); + +// Ok(()) +// } +// } + +super::macros::reason_code!(DisconnectReasonCode, NormalDisconnection, DisconnectWithWillMessage, UnspecifiedError, @@ -167,89 +224,125 @@ pub enum DisconnectReasonCode { ConnectionRateExceeded, MaximumConnectTime, SubscriptionIdentifiersNotSupported, - WildcardSubscriptionsNotSupported, -} - -impl MqttRead for DisconnectReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), - 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), - 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), - 0x81 => Ok(DisconnectReasonCode::MalformedPacket), - 0x82 => Ok(DisconnectReasonCode::ProtocolError), - 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), - 0x87 => Ok(DisconnectReasonCode::NotAuthorized), - 0x89 => Ok(DisconnectReasonCode::ServerBusy), - 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), - 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), - 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), - 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), - 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), - 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), - 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), - 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), - 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), - 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), - 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), - 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), - 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), - 0x9B => Ok(DisconnectReasonCode::QosNotSupported), - 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), - 0x9D => Ok(DisconnectReasonCode::ServerMoved), - 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), - 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), - 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), - 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), - 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for DisconnectReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - DisconnectReasonCode::NormalDisconnection => 0x00, - DisconnectReasonCode::DisconnectWithWillMessage => 0x04, - DisconnectReasonCode::UnspecifiedError => 0x80, - DisconnectReasonCode::MalformedPacket => 0x81, - DisconnectReasonCode::ProtocolError => 0x82, - DisconnectReasonCode::ImplementationSpecificError => 0x83, - DisconnectReasonCode::NotAuthorized => 0x87, - DisconnectReasonCode::ServerBusy => 0x89, - DisconnectReasonCode::ServerShuttingDown => 0x8B, - DisconnectReasonCode::KeepAliveTimeout => 0x8D, - DisconnectReasonCode::SessionTakenOver => 0x8E, - DisconnectReasonCode::TopicFilterInvalid => 0x8F, - DisconnectReasonCode::TopicNameInvalid => 0x90, - DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, - DisconnectReasonCode::TopicAliasInvalid => 0x94, - DisconnectReasonCode::PacketTooLarge => 0x95, - DisconnectReasonCode::MessageRateTooHigh => 0x96, - DisconnectReasonCode::QuotaExceeded => 0x97, - DisconnectReasonCode::AdministrativeAction => 0x98, - DisconnectReasonCode::PayloadFormatInvalid => 0x99, - DisconnectReasonCode::RetainNotSupported => 0x9A, - DisconnectReasonCode::QosNotSupported => 0x9B, - DisconnectReasonCode::UseAnotherServer => 0x9C, - DisconnectReasonCode::ServerMoved => 0x9D, - DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, - DisconnectReasonCode::ConnectionRateExceeded => 0x9F, - DisconnectReasonCode::MaximumConnectTime => 0xA0, - DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, - DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, - }; - - buf.put_u8(val); - - Ok(()) - } -} + WildcardSubscriptionsNotSupported +); + +// #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum DisconnectReasonCode { +// #[default] +// NormalDisconnection, +// DisconnectWithWillMessage, +// UnspecifiedError, +// MalformedPacket, +// ProtocolError, +// ImplementationSpecificError, +// NotAuthorized, +// ServerBusy, +// ServerShuttingDown, +// KeepAliveTimeout, +// SessionTakenOver, +// TopicFilterInvalid, +// TopicNameInvalid, +// ReceiveMaximumExceeded, +// TopicAliasInvalid, +// PacketTooLarge, +// MessageRateTooHigh, +// QuotaExceeded, +// AdministrativeAction, +// PayloadFormatInvalid, +// RetainNotSupported, +// QosNotSupported, +// UseAnotherServer, +// ServerMoved, +// SharedSubscriptionsNotSupported, +// ConnectionRateExceeded, +// MaximumConnectTime, +// SubscriptionIdentifiersNotSupported, +// WildcardSubscriptionsNotSupported, +// } + +// impl MqttRead for DisconnectReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), +// 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), +// 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), +// 0x81 => Ok(DisconnectReasonCode::MalformedPacket), +// 0x82 => Ok(DisconnectReasonCode::ProtocolError), +// 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), +// 0x87 => Ok(DisconnectReasonCode::NotAuthorized), +// 0x89 => Ok(DisconnectReasonCode::ServerBusy), +// 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), +// 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), +// 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), +// 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), +// 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), +// 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), +// 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), +// 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), +// 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), +// 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), +// 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), +// 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), +// 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), +// 0x9B => Ok(DisconnectReasonCode::QosNotSupported), +// 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), +// 0x9D => Ok(DisconnectReasonCode::ServerMoved), +// 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), +// 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), +// 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), +// 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), +// 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for DisconnectReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// DisconnectReasonCode::NormalDisconnection => 0x00, +// DisconnectReasonCode::DisconnectWithWillMessage => 0x04, +// DisconnectReasonCode::UnspecifiedError => 0x80, +// DisconnectReasonCode::MalformedPacket => 0x81, +// DisconnectReasonCode::ProtocolError => 0x82, +// DisconnectReasonCode::ImplementationSpecificError => 0x83, +// DisconnectReasonCode::NotAuthorized => 0x87, +// DisconnectReasonCode::ServerBusy => 0x89, +// DisconnectReasonCode::ServerShuttingDown => 0x8B, + +// DisconnectReasonCode::KeepAliveTimeout => 0x8D, +// DisconnectReasonCode::SessionTakenOver => 0x8E, +// DisconnectReasonCode::TopicFilterInvalid => 0x8F, + +// DisconnectReasonCode::TopicNameInvalid => 0x90, +// DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, +// DisconnectReasonCode::TopicAliasInvalid => 0x94, +// DisconnectReasonCode::PacketTooLarge => 0x95, +// DisconnectReasonCode::MessageRateTooHigh => 0x96, +// DisconnectReasonCode::QuotaExceeded => 0x97, +// DisconnectReasonCode::AdministrativeAction => 0x98, +// DisconnectReasonCode::PayloadFormatInvalid => 0x99, +// DisconnectReasonCode::RetainNotSupported => 0x9A, +// DisconnectReasonCode::QosNotSupported => 0x9B, +// DisconnectReasonCode::UseAnotherServer => 0x9C, +// DisconnectReasonCode::ServerMoved => 0x9D, +// DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, +// DisconnectReasonCode::ConnectionRateExceeded => 0x9F, +// DisconnectReasonCode::MaximumConnectTime => 0xA0, +// DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, +// DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, +// }; + +// buf.put_u8(val); + +// Ok(()) +// } +// } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum PubAckReasonCode { @@ -410,6 +503,17 @@ impl MqttRead for PubRelReasonCode { } } +impl MqttAsyncRead for PubRelReasonCode where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let code = match stream.read_u8().await? { + 0x00 => PubRelReasonCode::Success, + 0x92 => PubRelReasonCode::PacketIdentifierNotFound, + t => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnknownProperty(t))), + }; + Ok((code, 1)) + } +} + impl MqttWrite for PubRelReasonCode { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let val = match self { diff --git a/mqrstt/src/packets/suback.rs b/mqrstt/src/packets/suback.rs index 3ea0691..95718f2 100644 --- a/mqrstt/src/packets/suback.rs +++ b/mqrstt/src/packets/suback.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use super::{ error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, read_variable_integer, reason_codes::SubAckReasonCode, variable_integer_len, write_variable_integer, PacketType, PropertyType, @@ -18,7 +18,7 @@ pub struct SubAck { pub reason_codes: Vec, } -impl VariableHeaderRead for SubAck { +impl PacketRead for SubAck { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubAckProperties::read(&mut buf)?; @@ -41,7 +41,7 @@ impl VariableHeaderRead for SubAck { } } -impl VariableHeaderWrite for SubAck { +impl PacketWrite for SubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.packet_identifier); @@ -138,7 +138,7 @@ mod test { use bytes::BytesMut; use super::SubAck; - use crate::packets::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}; + use crate::packets::mqtt_traits::{PacketRead, PacketWrite}; #[test] fn read_write_suback() { diff --git a/mqrstt/src/packets/subscribe.rs b/mqrstt/src/packets/subscribe.rs index 2c44d03..0142ea5 100644 --- a/mqrstt/src/packets/subscribe.rs +++ b/mqrstt/src/packets/subscribe.rs @@ -2,7 +2,7 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, }; use bytes::{Buf, BufMut}; @@ -24,7 +24,7 @@ impl Subscribe { } } -impl VariableHeaderRead for Subscribe { +impl PacketRead for Subscribe { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubscribeProperties::read(&mut buf)?; @@ -48,7 +48,7 @@ impl VariableHeaderRead for Subscribe { } } -impl VariableHeaderWrite for Subscribe { +impl PacketWrite for Subscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -390,7 +390,7 @@ mod tests { use bytes::{Bytes, BytesMut}; use crate::packets::{ - mqtt_traits::{MqttRead, VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{MqttRead, PacketRead, PacketWrite}, Packet, }; diff --git a/mqrstt/src/packets/unsuback.rs b/mqrstt/src/packets/unsuback.rs index 5d53378..f7c556b 100644 --- a/mqrstt/src/packets/unsuback.rs +++ b/mqrstt/src/packets/unsuback.rs @@ -1,7 +1,7 @@ use bytes::BufMut; use super::error::{DeserializeError, SerializeError}; -use super::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; +use super::mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; use super::{read_variable_integer, reason_codes::UnsubAckReasonCode, write_variable_integer, PacketType, PropertyType}; #[derive(Debug, Default, PartialEq, Eq, Clone)] @@ -11,7 +11,7 @@ pub struct UnsubAck { pub reason_codes: Vec, } -impl VariableHeaderRead for UnsubAck { +impl PacketRead for UnsubAck { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubAckProperties::read(&mut buf)?; @@ -34,7 +34,7 @@ impl VariableHeaderRead for UnsubAck { } } -impl VariableHeaderWrite for UnsubAck { +impl PacketWrite for UnsubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.packet_identifier); self.properties.write(buf)?; @@ -125,7 +125,7 @@ mod tests { use bytes::{Bytes, BytesMut}; use crate::packets::{ - mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}, + mqtt_traits::{PacketRead, PacketWrite}, unsuback::UnsubAck, }; diff --git a/mqrstt/src/packets/unsubscribe.rs b/mqrstt/src/packets/unsubscribe.rs index 29f783a..af746ee 100644 --- a/mqrstt/src/packets/unsubscribe.rs +++ b/mqrstt/src/packets/unsubscribe.rs @@ -2,7 +2,7 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, VariableHeaderRead, VariableHeaderWrite, WireLength}, + mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, }; use bytes::BufMut; @@ -24,7 +24,7 @@ impl Unsubscribe { } } -impl VariableHeaderRead for Unsubscribe { +impl PacketRead for Unsubscribe { fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubscribeProperties::read(&mut buf)?; @@ -47,7 +47,7 @@ impl VariableHeaderRead for Unsubscribe { } } -impl VariableHeaderWrite for Unsubscribe { +impl PacketWrite for Unsubscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); self.properties.write(buf)?; @@ -280,7 +280,7 @@ mod tests { use bytes::{Bytes, BytesMut}; - use crate::packets::mqtt_traits::{VariableHeaderRead, VariableHeaderWrite}; + use crate::packets::mqtt_traits::{PacketRead, PacketWrite}; use super::Unsubscribe; diff --git a/mqrstt/src/util/constants.rs b/mqrstt/src/util/constants.rs index 023c60d..5ea5612 100644 --- a/mqrstt/src/util/constants.rs +++ b/mqrstt/src/util/constants.rs @@ -1,3 +1,3 @@ -pub const DEFAULT_RECEIVE_MAXIMUM: u16 = 65535; -pub const MAXIMUM_PACKET_SIZE: u32 = 268435455; -pub const MAXIMUM_TOPIC_SIZE: usize = 65535; +pub(crate) const DEFAULT_RECEIVE_MAXIMUM: u16 = 65535; +pub(crate) const MAXIMUM_PACKET_SIZE: u32 = 268435455; +pub(crate) const MAXIMUM_TOPIC_SIZE: usize = 65535; From dfd5a5f15d8dde81072b62f23270dabc7dbe9436 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Sun, 17 Nov 2024 11:47:18 +0100 Subject: [PATCH 03/45] fix: test_read_write_properties in publish stack overflow --- .../src/packets/macros/properties_macros.rs | 455 ++++++++++-------- mqrstt/src/packets/mod.rs | 4 +- mqrstt/src/packets/publish.rs | 172 ++++--- mqrstt/src/tests/test_packets.rs | 10 +- rust-toolchain | 2 +- 5 files changed, 336 insertions(+), 307 deletions(-) diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index 2071051..c7a2c05 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -105,60 +105,61 @@ macro_rules! properties_struct { pub session_expiry_interval: Option, )); ); - ( @ $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( - $($result)* - /// 3.2.2.3.3 Receive Maximum - /// 33 (0x21) Byte, Identifier of the Receive Maximum - pub receive_maximum: Option, - )); - ); - ( @ $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.3.2.3.4 Topic Alias - /// 35 (0x23) Byte, Identifier of the Topic Alias. - pub topic_alias: Option, + /// 3.2.2.3.7 Assigned Client Identifier + /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. + pub assigned_client_id: Option>, )); ); - ( @ $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.4 Maximum QoS - /// 36 (0x24) Byte, Identifier of the Maximum QoS. - pub maximum_qos: Option<$crate::packets::QoS>, + /// 3.2.2.3.14 Server Keep Alive + /// 19 (0x13) Byte, Identifier of the Server Keep Alive + pub server_keep_alive: Option, )); ); - ( @ $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.5 Retain Available - /// 37 (0x25) Byte, Identifier of Retain Available. - pub retain_available: Option, + /// 3.2.2.3.17 Authentication Method + /// 21 (0x15) Byte, Identifier of the Authentication Method + pub authentication_method: Option>, )); ); - ( @ $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.6 Maximum Packet Size - /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. - pub maximum_packet_size: Option, + /// 3.2.2.3.18 Authentication Data + /// 22 (0x16) Byte, Identifier of the Authentication Data + pub authentication_data: Option>, )); ); - ( @ $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + // ( @ $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + // // Missing + // ); + // ( @ $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + // // Missing + // ); + // ( @ $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + // // Missing + // ); + ( @ $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.7 Assigned Client Identifier - /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. - pub assigned_client_id: Option>, + /// 3.2.2.3.15 Response Information + /// 26 (0x1A) Byte, Identifier of the Response Information. + pub response_info: Option>, )); ); - ( @ $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.8 Topic Alias Maximum - /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. - pub topic_alias_maximum: Option, + /// 3.2.2.3.16 Server Reference + /// 28 (0x1C) Byte, Identifier of the Server Reference + pub server_reference: Option>, )); ); ( @ $name:ident { ReasonString, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -169,79 +170,84 @@ macro_rules! properties_struct { pub reason_string: Option>, )); ); - ( @ $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.10 User Property - /// 38 (0x26) Byte, Identifier of User Property. - pub user_properties: Vec<(Box, Box)>, + /// 3.2.2.3.3 Receive Maximum + /// 33 (0x21) Byte, Identifier of the Receive Maximum + pub receive_maximum: Option, )); ); - ( @ $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.11 Wildcard Subscription Available - /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. - pub wildcards_available: Option, + /// 3.2.2.3.8 Topic Alias Maximum + /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. + pub topic_alias_maximum: Option, )); ); - ( @ $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.12 Subscription Identifiers Available - /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. - pub subscription_ids_available: Option, + /// 3.3.2.3.4 Topic Alias + /// 35 (0x23) Byte, Identifier of the Topic Alias. + pub topic_alias: Option, )); ); - ( @ $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.13 Shared Subscription Available - /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. - pub shared_subscription_available: Option, + /// 3.2.2.3.4 Maximum QoS + /// 36 (0x24) Byte, Identifier of the Maximum QoS. + pub maximum_qos: Option<$crate::packets::QoS>, )); ); - ( @ $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.14 Server Keep Alive - /// 19 (0x13) Byte, Identifier of the Server Keep Alive - pub server_keep_alive: Option, + /// 3.2.2.3.5 Retain Available + /// 37 (0x25) Byte, Identifier of Retain Available. + pub retain_available: Option, )); ); - ( @ $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.15 Response Information - /// 26 (0x1A) Byte, Identifier of the Response Information. - pub response_info: Option>, + /// 3.2.2.3.10 User Property + /// 38 (0x26) Byte, Identifier of User Property. + pub user_properties: Vec<(Box, Box)>, )); ); - ( @ $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.16 Server Reference - /// 28 (0x1C) Byte, Identifier of the Server Reference - pub server_reference: Option>, + /// 3.2.2.3.6 Maximum Packet Size + /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. + pub maximum_packet_size: Option, )); ); - - ( @ $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.17 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method - pub authentication_method: Option>, + /// 3.2.2.3.11 Wildcard Subscription Available + /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. + pub wildcards_available: Option, )); ); - ( @ $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* - /// 3.2.2.3.18 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - // There is a small inconsistency here with authentication_data in the connect packet. - // This is Option while that type uses just Bytes. - pub authentication_data: Option>, + /// 3.2.2.3.12 Subscription Identifiers Available + /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. + pub subscription_ids_available: Option, + )); + ); + ( @ $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.2.2.3.13 Shared Subscription Available + /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. + pub shared_subscription_available: Option, )); ); ( @ $name:ident { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -250,16 +256,7 @@ macro_rules! properties_struct { } macro_rules! properties_read_matches { - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => { - { - if $properties.session_expiry_interval.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval))); - } - let (prop_body, read_bytes) = u32::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.session_expiry_interval = Some(prop_body); - } - }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::PayloadFormatIndicator) => { { if $properties.payload_format_indicator.is_some() { @@ -312,69 +309,89 @@ macro_rules! properties_read_matches { }; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => { { - let (prop_body, read_bytes) = VariableInteger::async_read($stream).await?; + let (prop_body, read_bytes) = $crate::packets::read_async_variable_integer($stream).await?; $read_property_bytes += read_bytes; $properties.subscription_identifier.push(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => { { - if $properties.receive_maximum.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); + if $properties.session_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval))); } - let (prop_body, read_bytes) = u16::async_read($stream).await?; + let (prop_body, read_bytes) = u32::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.receive_maximum = Some(prop_body); + $properties.session_expiry_interval = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => { { - if $properties.maximum_qos.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); + if $properties.assigned_client_id.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier))); } - let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; + let (prop_body, read_bytes) = Box::::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.maximum_qos = Some(prop_body); + $properties.assigned_client_id = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => { { - if $properties.retain_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + if $properties.server_keep_alive.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); } - let (prop_body, read_bytes) = bool::async_read($stream).await?; + let (prop_body, read_bytes) = u16::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.retain_available = Some(prop_body); + $properties.server_keep_alive = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => { { - if $properties.maximum_packet_size.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + if $properties.authentication_method.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod))); } - let (prop_body, read_bytes) = u32::async_read($stream).await?; + let (prop_body, read_bytes) = Box::::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.maximum_packet_size = Some(prop_body); + $properties.authentication_method = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => { { - if $properties.assigned_client_id.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier))); + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_data = Some(prop_body); + } + }; + // ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => { + // { + // if $properties.authentication_data.is_some() { + // return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation))); + // } + // let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + // $read_property_bytes += read_bytes; + // $properties.authentication_data = Some(prop_body); + // } + // }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => { + { + if $properties.response_info.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation))); } let (prop_body, read_bytes) = Box::::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.assigned_client_id = Some(prop_body); + $properties.response_info = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => { { - if $properties.topic_alias_maximum.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum))); + if $properties.server_reference.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); } - let (prop_body, read_bytes) = u16::async_read($stream).await?; + let (prop_body, read_bytes) = Box::::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.topic_alias_maximum = Some(prop_body); + $properties.server_reference = Some(prop_body); } }; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReasonString) => { @@ -387,107 +404,113 @@ macro_rules! properties_read_matches { $properties.reason_string = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => { { - let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; + if $properties.receive_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; $read_property_bytes += read_bytes; - - $properties.user_properties.push((prop_body_key, prop_body_value)) - } + $properties.receive_maximum = Some(prop_body); + } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => { { - if $properties.wildcards_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable))); + if $properties.topic_alias_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum))); } - let (prop_body, read_bytes) = bool::async_read($stream).await?; + let (prop_body, read_bytes) = u16::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.wildcards_available = Some(prop_body); + $properties.topic_alias_maximum = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAlias) => { { - if $properties.subscription_ids_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable))); + if $properties.topic_alias.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval))); } - let (prop_body, read_bytes) = bool::async_read($stream).await?; + let (prop_body, read_bytes) = u16::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.subscription_ids_available = Some(prop_body); + $properties.topic_alias = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => { { - if $properties.shared_subscription_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable))); + if $properties.maximum_qos.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); } - let (prop_body, read_bytes) = bool::async_read($stream).await?; + let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.shared_subscription_available = Some(prop_body); + $properties.maximum_qos = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => { { - if $properties.server_keep_alive.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); + if $properties.retain_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); } - let (prop_body, read_bytes) = u16::async_read($stream).await?; + let (prop_body, read_bytes) = bool::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.server_keep_alive = Some(prop_body); + $properties.retain_available = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => { { - if $properties.response_info.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation))); + let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + + $properties.user_properties.push((prop_body_key, prop_body_value)) + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => { + { + if $properties.maximum_packet_size.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; + let (prop_body, read_bytes) = u32::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.response_info = Some(prop_body); + $properties.maximum_packet_size = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => { { - if $properties.server_reference.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); + if $properties.wildcards_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable))); } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; + let (prop_body, read_bytes) = bool::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.server_reference = Some(prop_body); + $properties.wildcards_available = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => { { - if $properties.authentication_method.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod))); + if $properties.subscription_ids_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable))); } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; + let (prop_body, read_bytes) = bool::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.authentication_method = Some(prop_body); + $properties.subscription_ids_available = Some(prop_body); } }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => { { - if $properties.authentication_data.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData))); + if $properties.shared_subscription_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable))); } - let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + let (prop_body, read_bytes) = bool::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.authentication_data = Some(prop_body); + $properties.shared_subscription_available = Some(prop_body); } }; } macro_rules! properties_wire_length{ - ($self:ident, $len:ident, PropertyType::SessionExpiryInterval) => { - if $self.session_expiry_interval.is_some() { - $len += 1 + 4; - } - }; + ($self:ident, $len:ident, PropertyType::PayloadFormatIndicator) => { if $self.payload_format_indicator.is_some() { - $len += 1 + 4; + $len += 2; } }; ($self:ident, $len:ident, PropertyType::MessageExpiryInterval) => { @@ -495,9 +518,9 @@ macro_rules! properties_wire_length{ $len += 1 + 4; } }; - ($self:ident, $len:ident, PropertyType::TopicAlias) => { - if $self.topic_alias.is_some() { - $len += 3; + ($self:ident, $len:ident, PropertyType::ContentType) => { + if let Some(content_type) = &($self.content_type) { + $len += 1 + content_type.wire_len(); } }; ($self:ident, $len:ident, PropertyType::ResponseTopic) => { @@ -515,29 +538,52 @@ macro_rules! properties_wire_length{ $len += 1 + $crate::packets::variable_integer_len(*sub_id); } }; - ($self:ident, $len:ident, PropertyType::ReceiveMaximum) => { - if $self.receive_maximum.is_some() { + ($self:ident, $len:ident, PropertyType::SessionExpiryInterval) => { + if $self.session_expiry_interval.is_some() { + $len += 1 + 4; + } + }; + ($self:ident, $len:ident, PropertyType::AssignedClientIdentifier) => { + if let Some(client_id) = $self.assigned_client_id.as_ref() { + $len += 1 + client_id.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ServerKeepAlive) => { + if $self.server_keep_alive.is_some() { $len += 1 + 2; } }; - ($self:ident, $len:ident, PropertyType::MaximumQos) => { - if $self.maximum_qos.is_some() { - $len += 1 + 1; + ($self:ident, $len:ident, PropertyType::AuthenticationMethod) => { + if let Some(authentication_method) = &($self.authentication_method) { + $len += 1 + authentication_method.wire_len(); } }; - ($self:ident, $len:ident, PropertyType::RetainAvailable) => { - if $self.retain_available.is_some() { - $len += 1 + 1; + ($self:ident, $len:ident, PropertyType::AuthenticationData) => { + if $self.authentication_data.is_some() && $self.authentication_method.is_some() { + $len += 1 + $self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); } }; - ($self:ident, $len:ident, PropertyType::MaximumPacketSize) => { - if $self.maximum_packet_size.is_some() { - $len += 1 + 4; + // ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { + //Will Delay Interval + // ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { + ($self:ident, $len:ident, PropertyType::ResponseInformation) => { + if let Some(response_info) = &($self.response_info) { + $len += 1 + response_info.wire_len(); } }; - ($self:ident, $len:ident, PropertyType::AssignedClientIdentifier) => { - if let Some(client_id) = $self.assigned_client_id.as_ref() { - $len += 1 + client_id.wire_len(); + ($self:ident, $len:ident, PropertyType::ServerReference) => { + if let Some(server_reference) = &($self.server_reference) { + $len += 1 + server_reference.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ReasonString) => { + if let Some(reason_string) = &($self.reason_string) { + $len += 1 + reason_string.wire_len(); + } + }; + ($self:ident, $len:ident, PropertyType::ReceiveMaximum) => { + if $self.receive_maximum.is_some() { + $len += 1 + 2; } }; ($self:ident, $len:ident, PropertyType::TopicAliasMaximum) => { @@ -545,9 +591,19 @@ macro_rules! properties_wire_length{ $len += 1 + 2; } }; - ($self:ident, $len:ident, PropertyType::ReasonString) => { - if let Some(reason_string) = &($self.reason_string) { - $len += 1 + reason_string.wire_len(); + ($self:ident, $len:ident, PropertyType::TopicAlias) => { + if $self.topic_alias.is_some() { + $len += 3; + } + }; + ($self:ident, $len:ident, PropertyType::MaximumQos) => { + if $self.maximum_qos.is_some() { + $len += 1 + 1; + } + }; + ($self:ident, $len:ident, PropertyType::RetainAvailable) => { + if $self.retain_available.is_some() { + $len += 1 + 1; } }; ($self:ident, $len:ident, PropertyType::UserProperty) => { @@ -557,6 +613,11 @@ macro_rules! properties_wire_length{ $len += value.wire_len(); } }; + ($self:ident, $len:ident, PropertyType::MaximumPacketSize) => { + if $self.maximum_packet_size.is_some() { + $len += 1 + 4; + } + }; ($self:ident, $len:ident, PropertyType::WildcardSubscriptionAvailable) => { if $self.wildcards_available.is_some() { $len += 1 + 1; @@ -572,35 +633,9 @@ macro_rules! properties_wire_length{ $len += 1 + 1; } }; - ($self:ident, $len:ident, PropertyType::ServerKeepAlive) => { - if $self.server_keep_alive.is_some() { - $len += 1 + 2; - } - }; - ($self:ident, $len:ident, PropertyType::ResponseInformation) => { - if let Some(response_info) = &($self.response_info) { - $len += 1 + response_info.wire_len(); - } - }; - ($self:ident, $len:ident, PropertyType::ServerReference) => { - if let Some(server_reference) = &($self.server_reference) { - $len += 1 + server_reference.wire_len(); - } - }; - ($self:ident, $len:ident, PropertyType::AuthenticationMethod) => { - if let Some(authentication_method) = &($self.authentication_method) { - $len += 1 + authentication_method.wire_len(); - } - }; - ($self:ident, $len:ident, PropertyType::AuthenticationData) => { - if $self.authentication_data.is_some() && $self.authentication_method.is_some() { - $len += 1 + $self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); - } - }; ($self:ident, $len:ident, $unknown:ident) => ( compile_error!(concat!("Unknown property: ", stringify!($unknown))); ); - } pub(crate) use define_properties; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index dfcce1f..5b31414 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -618,7 +618,7 @@ impl From<&PropertyType> for u8 { impl From for u8 { fn from(value: PropertyType) -> Self { - value.into() + value as u8 } } @@ -1059,7 +1059,7 @@ mod tests { message_expiry_interval: None, topic_alias: None, response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), + correlation_data: Some(b"1212".to_vec()), subscription_identifier: vec![1], user_properties: vec![], content_type: None, diff --git a/mqrstt/src/packets/publish.rs b/mqrstt/src/packets/publish.rs index 15fa609..38059e4 100644 --- a/mqrstt/src/packets/publish.rs +++ b/mqrstt/src/packets/publish.rs @@ -122,52 +122,53 @@ impl PacketValidation for Publish { } } -// super::macros::define_properties!(PublishProperties, -// PayloadFormatIndicator, -// MessageExpiryInterval, -// ContentType, -// ResponseTopic, -// CorrelationData, -// SubscriptionIdentifier, -// TopicAlias, -// UserProperty -// ); - -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct PublishProperties { - /// 3.3.2.3.2 Payload Format Indicator - /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. - pub payload_format_indicator: Option, - - /// 3.3.2.3.3 Message Expiry Interval - /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. - pub message_expiry_interval: Option, - - /// 3.3.2.3.9 Content Type - /// 3 (0x03) Identifier of the Content Type - pub content_type: Option>, - - /// 3.3.2.3.5 Response Topic - /// 8 (0x08) Byte, Identifier of the Response Topic. - pub response_topic: Option>, - - /// 3.3.2.3.6 Correlation Data - /// 9 (0x09) Byte, Identifier of the Correlation Data. - pub correlation_data: Option, - - /// 3.3.2.3.8 Subscription Identifier - /// 11 (0x0B), Identifier of the Subscription Identifier. - pub subscription_identifier: Vec, - - /// 3.3.2.3.4 Topic Alias - /// 35 (0x23) Byte, Identifier of the Topic Alias. - pub topic_alias: Option, - - /// 3.3.2.3.7 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, - -} +super::macros::define_properties!(PublishProperties, + PayloadFormatIndicator, + MessageExpiryInterval, + ContentType, + ResponseTopic, + CorrelationData, + SubscriptionIdentifier, + TopicAlias, + UserProperty +); + + +// #[derive(Debug, Clone, PartialEq, Eq, Default)] +// pub struct PublishProperties { +// /// 3.3.2.3.2 Payload Format Indicator +// /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. +// pub payload_format_indicator: Option, + +// /// 3.3.2.3.3 Message Expiry Interval +// /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. +// pub message_expiry_interval: Option, + +// /// 3.3.2.3.9 Content Type +// /// 3 (0x03) Identifier of the Content Type +// pub content_type: Option>, + +// /// 3.3.2.3.5 Response Topic +// /// 8 (0x08) Byte, Identifier of the Response Topic. +// pub response_topic: Option>, + +// /// 3.3.2.3.6 Correlation Data +// /// 9 (0x09) Byte, Identifier of the Correlation Data. +// pub correlation_data: Option, + +// /// 3.3.2.3.8 Subscription Identifier +// /// 11 (0x0B), Identifier of the Subscription Identifier. +// pub subscription_identifier: Vec, + +// /// 3.3.2.3.4 Topic Alias +// /// 35 (0x23) Byte, Identifier of the Topic Alias. +// pub topic_alias: Option, + +// /// 3.3.2.3.7 User Property +// /// 38 (0x26) Byte, Identifier of the User Property. +// pub user_properties: Vec<(Box, Box)>, + +// } impl MqttRead for PublishProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -213,7 +214,7 @@ impl MqttRead for PublishProperties { if properties.correlation_data.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); } - properties.correlation_data = Some(Bytes::read(&mut property_data)?); + properties.correlation_data = Some(Vec::::read(&mut property_data)?); } PropertyType::SubscriptionIdentifier => { properties.subscription_identifier.push(read_variable_integer(&mut property_data)?.0); @@ -278,38 +279,38 @@ impl MqttWrite for PublishProperties { } } -impl WireLength for PublishProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - - if self.payload_format_indicator.is_some() { - len += 2; - } - if self.message_expiry_interval.is_some() { - len += 5; - } - if self.topic_alias.is_some() { - len += 3; - } - if let Some(response_topic) = &self.response_topic { - len += 1 + response_topic.wire_len(); - } - if let Some(correlation_data) = &self.correlation_data { - len += 1 + correlation_data.wire_len(); - } - for sub_id in &self.subscription_identifier { - len += 1 + variable_integer_len(*sub_id); - } - for (key, val) in &self.user_properties { - len += 1 + key.wire_len() + val.wire_len(); - } - if let Some(content_type) = &self.content_type { - len += 1 + content_type.wire_len(); - } - - len - } -} +// impl WireLength for PublishProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; + +// if self.payload_format_indicator.is_some() { +// len += 2; +// } +// if self.message_expiry_interval.is_some() { +// len += 5; +// } +// if self.topic_alias.is_some() { +// len += 3; +// } +// if let Some(response_topic) = &self.response_topic { +// len += 1 + response_topic.wire_len(); +// } +// if let Some(correlation_data) = &self.correlation_data { +// len += 1 + correlation_data.wire_len(); +// } +// for sub_id in &self.subscription_identifier { +// len += 1 + variable_integer_len(*sub_id); +// } +// for (key, val) in &self.user_properties { +// len += 1 + key.wire_len() + val.wire_len(); +// } +// if let Some(content_type) = &self.content_type { +// len += 1 + content_type.wire_len(); +// } + +// len +// } +// } #[cfg(test)] mod tests { @@ -325,7 +326,6 @@ mod tests { #[test] fn test_read_write_properties() { let first_byte = 0b0011_0100; - dbg!("1"); let mut properties = [1, 0, 2].to_vec(); properties.extend(4_294_967_295u32.to_be_bytes()); @@ -351,23 +351,17 @@ mod tests { ] .to_vec(), ); - dbg!("a"); let rem_len = buf_one.len(); - let buf = BytesMut::from(&buf_one[..]); + let buf = buf_one.clone(); let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); - dbg!("b"); - - let mut result_buf = BytesMut::new(); - dbg!(p.wire_len()); + let mut result_buf = BytesMut::with_capacity(1000); p.write(&mut result_buf).unwrap(); - // dbg!(p.clone()); - - // assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + assert_eq!(buf_one.to_vec(), result_buf.to_vec()) } #[test] diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 747b1d1..5dae121 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -19,7 +19,7 @@ fn publish_packet_1() -> Packet { message_expiry_interval: None, topic_alias: None, response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), + correlation_data: Some(b"1212".to_vec()), subscription_identifier: vec![1], user_properties: vec![], content_type: None, @@ -39,7 +39,7 @@ fn publish_packet_2() -> Packet { message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), + correlation_data: Some(b"1212".to_vec()), subscription_identifier: vec![1], user_properties: vec![], content_type: None, @@ -59,7 +59,7 @@ fn publish_packet_3() -> Packet { message_expiry_interval: Some(3600), topic_alias: None, response_topic: Some("Please respond here thank you".into()), - correlation_data: Some(Bytes::from_static(b"5420874")), + correlation_data: Some(b"5420874".to_vec()), subscription_identifier: vec![], user_properties: vec![("blabla".into(), "another blabla".into())], content_type: None, @@ -79,7 +79,7 @@ fn publish_packet_4() -> Packet { message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), + correlation_data: Some(b"1212".to_vec()), subscription_identifier: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), @@ -112,7 +112,7 @@ pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifie message_expiry_interval: Some(3600), topic_alias: Some(1), response_topic: None, - correlation_data: Some(Bytes::from_static(b"1212")), + correlation_data: Some(b"1212".to_vec()), subscription_identifier: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), diff --git a/rust-toolchain b/rust-toolchain index 982f51e..9611677 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1,2 +1,2 @@ [toolchain] -channel = "1.75.0" \ No newline at end of file +channel = "1.82.0" \ No newline at end of file From 9b211e5f4683f43bd609d22c5b94cdf2251795d7 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:14:28 +0100 Subject: [PATCH 04/45] delete Cargo.lock --- Cargo.lock | 1944 ---------------------------------------------------- 1 file changed, 1944 deletions(-) delete mode 100644 Cargo.lock diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index a8b4739..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,1944 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - -[[package]] -name = "aho-corasick" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" -dependencies = [ - "memchr", -] - -[[package]] -name = "anes" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" - -[[package]] -name = "anstyle" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" - -[[package]] -name = "async-channel" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" -dependencies = [ - "concurrent-queue", - "event-listener 2.5.3", - "futures-core", -] - -[[package]] -name = "async-channel" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ca33f4bc4ed1babef42cad36cc1f51fa88be00420404e5b1e80ab1b18f7678c" -dependencies = [ - "concurrent-queue", - "event-listener 4.0.2", - "event-listener-strategy", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-executor" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17ae5ebefcc48e7452b4987947920dac9450be1110cadf34d1b8c116bdbaf97c" -dependencies = [ - "async-lock 3.2.0", - "async-task", - "concurrent-queue", - "fastrand 2.0.1", - "futures-lite 2.1.0", - "slab", -] - -[[package]] -name = "async-fs" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "279cf904654eeebfa37ac9bb1598880884924aab82e290aa65c9e77a0e142e06" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "async-fs" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd1f344136bad34df1f83a47f3fd7f2ab85d75cb8a940af4ccf6d482a84ea01b" -dependencies = [ - "async-lock 3.2.0", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "async-io" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fc5b45d93ef0529756f812ca52e44c221b35341892d3dcc34132ac02f3dd2af" -dependencies = [ - "async-lock 2.8.0", - "autocfg", - "cfg-if", - "concurrent-queue", - "futures-lite 1.13.0", - "log", - "parking", - "polling 2.8.0", - "rustix 0.37.27", - "slab", - "socket2 0.4.10", - "waker-fn", -] - -[[package]] -name = "async-io" -version = "2.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6afaa937395a620e33dc6a742c593c01aced20aa376ffb0f628121198578ccc7" -dependencies = [ - "async-lock 3.2.0", - "cfg-if", - "concurrent-queue", - "futures-io", - "futures-lite 2.1.0", - "parking", - "polling 3.3.1", - "rustix 0.38.28", - "slab", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-lock" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" -dependencies = [ - "event-listener 2.5.3", -] - -[[package]] -name = "async-lock" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" -dependencies = [ - "event-listener 4.0.2", - "event-listener-strategy", - "pin-project-lite", -] - -[[package]] -name = "async-net" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0434b1ed18ce1cf5769b8ac540e33f01fa9471058b5e89da9e06f3c882a8c12f" -dependencies = [ - "async-io 1.13.0", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "async-net" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" -dependencies = [ - "async-io 2.2.2", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "async-process" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea6438ba0a08d81529c69b36700fa2f95837bfe3e776ab39cde9c14d9149da88" -dependencies = [ - "async-io 1.13.0", - "async-lock 2.8.0", - "async-signal", - "blocking", - "cfg-if", - "event-listener 3.1.0", - "futures-lite 1.13.0", - "rustix 0.38.28", - "windows-sys 0.48.0", -] - -[[package]] -name = "async-process" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15c1cd5d253ecac3d3cf15e390fd96bd92a13b1d14497d81abf077304794fb04" -dependencies = [ - "async-channel 2.1.1", - "async-io 2.2.2", - "async-lock 3.2.0", - "async-signal", - "blocking", - "cfg-if", - "event-listener 4.0.2", - "futures-lite 2.1.0", - "rustix 0.38.28", - "windows-sys 0.52.0", -] - -[[package]] -name = "async-rustls" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93b21a03b7c21702a0110f9f8d228763a533570deb376119042dabf33c37a01a" -dependencies = [ - "futures-io", - "rustls 0.20.9", - "webpki", -] - -[[package]] -name = "async-rustls" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd10f063fb367d26334e10c50c67ea31ac542b8c3402be2251db4cfc5d74ba66" -dependencies = [ - "futures-io", - "rustls 0.21.10", -] - -[[package]] -name = "async-signal" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e47d90f65a225c4527103a8d747001fc56e375203592b25ad103e1ca13124c5" -dependencies = [ - "async-io 2.2.2", - "async-lock 2.8.0", - "atomic-waker", - "cfg-if", - "futures-core", - "futures-io", - "rustix 0.38.28", - "signal-hook-registry", - "slab", - "windows-sys 0.48.0", -] - -[[package]] -name = "async-task" -version = "4.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d90cd0b264dfdd8eb5bad0a2c217c1f88fa96a8573f40e7b12de23fb468f46" - -[[package]] -name = "async-trait" -version = "0.1.77" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "atomic-waker" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" - -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - -[[package]] -name = "backtrace" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - -[[package]] -name = "base64" -version = "0.21.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" - -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - -[[package]] -name = "bitflags" -version = "2.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" - -[[package]] -name = "blocking" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" -dependencies = [ - "async-channel 2.1.1", - "async-lock 3.2.0", - "async-task", - "fastrand 2.0.1", - "futures-io", - "futures-lite 2.1.0", - "piper", - "tracing", -] - -[[package]] -name = "bumpalo" -version = "3.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" - -[[package]] -name = "bytes" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" - -[[package]] -name = "cast" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" - -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "ciborium" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" - -[[package]] -name = "ciborium-ll" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" -dependencies = [ - "ciborium-io", - "half", -] - -[[package]] -name = "clap" -version = "4.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d" -dependencies = [ - "clap_builder", -] - -[[package]] -name = "clap_builder" -version = "4.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" -dependencies = [ - "anstyle", - "clap_lex", -] - -[[package]] -name = "clap_lex" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" - -[[package]] -name = "concurrent-queue" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "criterion" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" -dependencies = [ - "anes", - "cast", - "ciborium", - "clap", - "criterion-plot", - "futures", - "is-terminal", - "itertools", - "num-traits", - "once_cell", - "oorandom", - "plotters", - "rayon", - "regex", - "serde", - "serde_derive", - "serde_json", - "tinytemplate", - "tokio", - "walkdir", -] - -[[package]] -name = "criterion-plot" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" -dependencies = [ - "cast", - "itertools", -] - -[[package]] -name = "crossbeam-deque" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" -dependencies = [ - "cfg-if", - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" -dependencies = [ - "autocfg", - "cfg-if", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - -[[package]] -name = "errno" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "event-listener" -version = "2.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" - -[[package]] -name = "event-listener" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93877bcde0eb80ca09131a08d23f0a5c18a620b01db137dba666d18cd9b30c2" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener" -version = "4.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "218a870470cce1469024e9fb66b901aa983929d81304a1cdb299f28118e550d5" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - -[[package]] -name = "event-listener-strategy" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" -dependencies = [ - "event-listener 4.0.2", - "pin-project-lite", -] - -[[package]] -name = "fastrand" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] - -[[package]] -name = "fastrand" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" - -[[package]] -name = "futures" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" -dependencies = [ - "futures-channel", - "futures-core", - "futures-executor", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" - -[[package]] -name = "futures-executor" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-io" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" - -[[package]] -name = "futures-lite" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49a9d51ce47660b1e808d3c990b4709f2f415d928835a17dfd16991515c46bce" -dependencies = [ - "fastrand 1.9.0", - "futures-core", - "futures-io", - "memchr", - "parking", - "pin-project-lite", - "waker-fn", -] - -[[package]] -name = "futures-lite" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aeee267a1883f7ebef3700f262d2d54de95dfaf38189015a74fdc4e0c7ad8143" -dependencies = [ - "fastrand 2.0.1", - "futures-core", - "futures-io", - "parking", - "pin-project-lite", -] - -[[package]] -name = "futures-macro" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "futures-sink" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" - -[[package]] -name = "futures-task" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" - -[[package]] -name = "futures-timer" -version = "3.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" - -[[package]] -name = "futures-util" -version = "0.3.30" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-macro", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - -[[package]] -name = "getrandom" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "gimli" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" - -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - -[[package]] -name = "half" -version = "1.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" - -[[package]] -name = "hermit-abi" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" - -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "io-lifetimes" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "is-terminal" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" -dependencies = [ - "hermit-abi", - "rustix 0.38.28", - "windows-sys 0.52.0", -] - -[[package]] -name = "itertools" -version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" - -[[package]] -name = "js-sys" -version = "0.3.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" - -[[package]] -name = "libc" -version = "0.2.151" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" - -[[package]] -name = "linux-raw-sys" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" - -[[package]] -name = "linux-raw-sys" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" - -[[package]] -name = "log" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" - -[[package]] -name = "matchers" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" -dependencies = [ - "regex-automata 0.1.10", -] - -[[package]] -name = "memchr" -version = "2.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" - -[[package]] -name = "miniz_oxide" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" -dependencies = [ - "adler", -] - -[[package]] -name = "mio" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.48.0", -] - -[[package]] -name = "mqrstt" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a854e678a3a205f8ac238694a8aae480684784f604a29367e47bd93a67bad83c" -dependencies = [ - "async-channel 1.9.0", - "async-trait", - "bytes", - "futures", - "smol 1.3.0", - "thiserror", - "tokio", -] - -[[package]] -name = "mqrstt" -version = "0.3.0" -dependencies = [ - "async-channel 2.1.1", - "async-rustls 0.4.1", - "bytes", - "criterion", - "futures", - "rand", - "rstest", - "rustls 0.21.10", - "rustls-pemfile", - "smol 2.0.0", - "thiserror", - "tokio", - "tokio-rustls", - "tracing", - "tracing-subscriber", - "webpki", -] - -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - -[[package]] -name = "num-traits" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - -[[package]] -name = "object" -version = "0.32.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" -dependencies = [ - "memchr", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "oorandom" -version = "11.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" - -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - -[[package]] -name = "parking" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" - -[[package]] -name = "pin-project-lite" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "piper" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" -dependencies = [ - "atomic-waker", - "fastrand 2.0.1", - "futures-io", -] - -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - -[[package]] -name = "polling" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b2d323e8ca7996b3e23126511a523f7e62924d93ecd5ae73b333815b0eb3dce" -dependencies = [ - "autocfg", - "bitflags 1.3.2", - "cfg-if", - "concurrent-queue", - "libc", - "log", - "pin-project-lite", - "windows-sys 0.48.0", -] - -[[package]] -name = "polling" -version = "3.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf63fa624ab313c11656b4cda960bfc46c410187ad493c41f6ba2d8c1e991c9e" -dependencies = [ - "cfg-if", - "concurrent-queue", - "pin-project-lite", - "rustix 0.38.28", - "tracing", - "windows-sys 0.52.0", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro2" -version = "1.0.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rayon" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - -[[package]] -name = "regex" -version = "1.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", -] - -[[package]] -name = "regex-automata" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax 0.8.2", -] - -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" - -[[package]] -name = "relative-path" -version = "1.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e898588f33fdd5b9420719948f9f2a32c922a246964576f71ba7f24f80610fbc" - -[[package]] -name = "ring" -version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" -dependencies = [ - "cc", - "libc", - "once_cell", - "spin 0.5.2", - "untrusted 0.7.1", - "web-sys", - "winapi", -] - -[[package]] -name = "ring" -version = "0.17.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" -dependencies = [ - "cc", - "getrandom", - "libc", - "spin 0.9.8", - "untrusted 0.9.0", - "windows-sys 0.48.0", -] - -[[package]] -name = "rstest" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" -dependencies = [ - "futures", - "futures-timer", - "rstest_macros", - "rustc_version", -] - -[[package]] -name = "rstest_macros" -version = "0.18.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" -dependencies = [ - "cfg-if", - "glob", - "proc-macro2", - "quote", - "regex", - "relative-path", - "rustc_version", - "syn", - "unicode-ident", -] - -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - -[[package]] -name = "rustix" -version = "0.37.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" -dependencies = [ - "bitflags 1.3.2", - "errno", - "io-lifetimes", - "libc", - "linux-raw-sys 0.3.8", - "windows-sys 0.48.0", -] - -[[package]] -name = "rustix" -version = "0.38.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" -dependencies = [ - "bitflags 2.4.1", - "errno", - "libc", - "linux-raw-sys 0.4.12", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustls" -version = "0.20.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" -dependencies = [ - "log", - "ring 0.16.20", - "sct", - "webpki", -] - -[[package]] -name = "rustls" -version = "0.21.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" -dependencies = [ - "log", - "ring 0.17.7", - "rustls-webpki", - "sct", -] - -[[package]] -name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64", -] - -[[package]] -name = "rustls-webpki" -version = "0.101.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "ryu" -version = "1.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" - -[[package]] -name = "same-file" -version = "1.0.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" -dependencies = [ - "winapi-util", -] - -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "semver" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" - -[[package]] -name = "serde" -version = "1.0.194" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b114498256798c94a0689e1a15fec6005dee8ac1f41de56404b67afc2a4b773" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.194" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3385e45322e8f9931410f01b3031ec534c3947d0e94c18049af4d9f9907d4e0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.110" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fbd975230bada99c8bb618e0c365c2eefa219158d5c6c29610fd09ff1833257" -dependencies = [ - "itoa", - "ryu", - "serde", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" -dependencies = [ - "libc", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" - -[[package]] -name = "smol" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f2b548cd8447f8de0fdf1c592929f70f4fc7039a05e47404b0d096ec6987a1" -dependencies = [ - "async-channel 1.9.0", - "async-executor", - "async-fs 1.6.0", - "async-io 1.13.0", - "async-lock 2.8.0", - "async-net 1.8.0", - "async-process 1.8.1", - "blocking", - "futures-lite 1.13.0", -] - -[[package]] -name = "smol" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e635339259e51ef85ac7aa29a1cd991b957047507288697a690e80ab97d07cad" -dependencies = [ - "async-channel 2.1.1", - "async-executor", - "async-fs 2.1.0", - "async-io 2.2.2", - "async-lock 3.2.0", - "async-net 2.0.0", - "async-process 2.0.1", - "blocking", - "futures-lite 2.1.0", -] - -[[package]] -name = "smol_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-rustls 0.3.0", - "async-trait", - "futures", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "smol 1.3.0", - "webpki", -] - -[[package]] -name = "smol_tls_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-rustls 0.3.0", - "async-trait", - "futures", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "smol 1.3.0", - "webpki", -] - -[[package]] -name = "socket2" -version = "0.4.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" -dependencies = [ - "libc", - "winapi", -] - -[[package]] -name = "socket2" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" -dependencies = [ - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "syn" -version = "2.0.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "sync_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "mqrstt 0.2.2", -] - -[[package]] -name = "thiserror" -version = "1.0.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "1.0.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thread_local" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" -dependencies = [ - "cfg-if", - "once_cell", -] - -[[package]] -name = "tinytemplate" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" -dependencies = [ - "serde", - "serde_json", -] - -[[package]] -name = "tokio" -version = "1.35.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "num_cpus", - "pin-project-lite", - "socket2 0.5.5", - "tokio-macros", - "windows-sys 0.48.0", -] - -[[package]] -name = "tokio-macros" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-rustls" -version = "0.24.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" -dependencies = [ - "rustls 0.21.10", - "tokio", -] - -[[package]] -name = "tokio_tcp_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-trait", - "mqrstt 0.2.2", - "rustls-pemfile", - "tokio", - "tokio-rustls", - "webpki", -] - -[[package]] -name = "tokio_tls_v0_2_2" -version = "0.1.0" -dependencies = [ - "async-trait", - "mqrstt 0.2.2", - "rustls 0.20.9", - "rustls-pemfile", - "tokio", - "tokio-rustls", - "webpki", -] - -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "untrusted" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - -[[package]] -name = "waker-fn" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3c4517f54858c779bbcbf228f4fca63d121bf85fbecb2dc578cdf4a39395690" - -[[package]] -name = "walkdir" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" -dependencies = [ - "same-file", - "winapi-util", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" - -[[package]] -name = "web-sys" -version = "0.3.66" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - -[[package]] -name = "webpki" -version = "0.22.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" -dependencies = [ - "ring 0.17.7", - "untrusted 0.9.0", -] - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" -dependencies = [ - "winapi", -] - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows-sys" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.0", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" -dependencies = [ - "windows_aarch64_gnullvm 0.52.0", - "windows_aarch64_msvc 0.52.0", - "windows_i686_gnu 0.52.0", - "windows_i686_msvc 0.52.0", - "windows_x86_64_gnu 0.52.0", - "windows_x86_64_gnullvm 0.52.0", - "windows_x86_64_msvc 0.52.0", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" From 0268b04fad15054b310abc197e9f6b9417e9bb93 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:16:44 +0100 Subject: [PATCH 05/45] delete examples folder --- examples/.gitignore | 1 - examples/smol_tcp_v0.2.2/Cargo.toml | 20 --- examples/smol_tcp_v0.2.2/src/main.rs | 68 --------- examples/smol_tls_v0.2.2/Cargo.toml | 20 --- .../smol_tls_v0.2.2/src/broker.emqx.io-ca.crt | 22 --- examples/smol_tls_v0.2.2/src/main.rs | 131 ------------------ examples/sync_tcp_v0.2.2/Cargo.toml | 10 -- examples/sync_tcp_v0.2.2/src/main.rs | 68 --------- examples/tokio_tcp_v0.2.2/Cargo.toml | 18 --- examples/tokio_tcp_v0.2.2/src/main.rs | 69 --------- examples/tokio_tls_v0.2.2/Cargo.toml | 19 --- .../src/broker.emqx.io-ca.crt | 22 --- examples/tokio_tls_v0.2.2/src/main.rs | 131 ------------------ 13 files changed, 599 deletions(-) delete mode 100644 examples/.gitignore delete mode 100644 examples/smol_tcp_v0.2.2/Cargo.toml delete mode 100644 examples/smol_tcp_v0.2.2/src/main.rs delete mode 100644 examples/smol_tls_v0.2.2/Cargo.toml delete mode 100644 examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt delete mode 100644 examples/smol_tls_v0.2.2/src/main.rs delete mode 100644 examples/sync_tcp_v0.2.2/Cargo.toml delete mode 100644 examples/sync_tcp_v0.2.2/src/main.rs delete mode 100644 examples/tokio_tcp_v0.2.2/Cargo.toml delete mode 100644 examples/tokio_tcp_v0.2.2/src/main.rs delete mode 100644 examples/tokio_tls_v0.2.2/Cargo.toml delete mode 100644 examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt delete mode 100644 examples/tokio_tls_v0.2.2/src/main.rs diff --git a/examples/.gitignore b/examples/.gitignore deleted file mode 100644 index ce9f139..0000000 --- a/examples/.gitignore +++ /dev/null @@ -1 +0,0 @@ -**/Cargo.lock \ No newline at end of file diff --git a/examples/smol_tcp_v0.2.2/Cargo.toml b/examples/smol_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index 33ec470..0000000 --- a/examples/smol_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "smol_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["smol"]} - -smol = { version = "1.3.0" } -futures = { version = "0.3.27", default-features = false, features = ["std", "async-await"] } - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } -async-rustls = { version = "0.3.0" } diff --git a/examples/smol_tcp_v0.2.2/src/main.rs b/examples/smol_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 5254322..0000000 --- a/examples/smol_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,68 +0,0 @@ -use async_trait::async_trait; -use mqrstt::{ - new_smol, - packets::{self, Packet}, - smol::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -fn main() { - smol::block_on(async { - let client_id = "SmolTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_smol(options); - - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); - }); -} diff --git a/examples/smol_tls_v0.2.2/Cargo.toml b/examples/smol_tls_v0.2.2/Cargo.toml deleted file mode 100644 index e5402d5..0000000 --- a/examples/smol_tls_v0.2.2/Cargo.toml +++ /dev/null @@ -1,20 +0,0 @@ -[package] -name = "smol_tls_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["smol"]} - -smol = { version = "1.3.0" } -futures = { version = "0.3.27", default-features = false, features = ["std", "async-await"] } - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } -async-rustls = { version = "0.3.0" } diff --git a/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt b/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt deleted file mode 100644 index fd4341d..0000000 --- a/examples/smol_tls_v0.2.2/src/broker.emqx.io-ca.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIQCDvgVpBCRrGhdWrJWZHHSjANBgkqhkiG9w0BAQUFADBh -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD -QTAeFw0wNjExMTAwMDAwMDBaFw0zMTExMTAwMDAwMDBaMGExCzAJBgNVBAYTAlVT -MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxGTAXBgNVBAsTEHd3dy5kaWdpY2VydC5j -b20xIDAeBgNVBAMTF0RpZ2lDZXJ0IEdsb2JhbCBSb290IENBMIIBIjANBgkqhkiG -9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4jvhEXLeqKTTo1eqUKKPC3eQyaKl7hLOllsB -CSDMAZOnTjC3U/dDxGkAV53ijSLdhwZAAIEJzs4bg7/fzTtxRuLWZscFs3YnFo97 -nh6Vfe63SKMI2tavegw5BmV/Sl0fvBf4q77uKNd0f3p4mVmFaG5cIzJLv07A6Fpt -43C/dxC//AH2hdmoRBBYMql1GNXRor5H4idq9Joz+EkIYIvUX7Q6hL+hqkpMfT7P -T19sdl6gSzeRntwi5m3OFBqOasv+zbMUZBfHWymeMr/y7vrTC0LUq7dBMtoM1O/4 -gdW7jVg/tRvoSSiicNoxBN33shbyTApOB6jtSj1etX+jkMOvJwIDAQABo2MwYTAO -BgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUA95QNVbR -TLtm8KPiGxvDl7I90VUwHwYDVR0jBBgwFoAUA95QNVbRTLtm8KPiGxvDl7I90VUw -DQYJKoZIhvcNAQEFBQADggEBAMucN6pIExIK+t1EnE9SsPTfrgT1eXkIoyQY/Esr -hMAtudXH/vTBH1jLuG2cenTnmCmrEbXjcKChzUyImZOMkXDiqw8cvpOp/2PV5Adg -06O/nVsJ8dWO41P0jmP6P6fbtGbfYmbW0W5BjfIttep3Sp+dWOIrWcBAI+0tKIJF -PnlUkiaY4IBIqDfv8NZ5YBberOgOzW6sRBc4L0na4UU+Krk2U886UAb3LujEV0ls -YSEY1QSteDwsOoBrp+uvFRTp2InBuThs4pFsiv9kuXclVzDAGySj4dzp30d8tbQk -CAUw7C29C79Fv1C5qfPrmAESrciIxpg0X40KPMbp1ZWVbd4= ------END CERTIFICATE----- diff --git a/examples/smol_tls_v0.2.2/src/main.rs b/examples/smol_tls_v0.2.2/src/main.rs deleted file mode 100644 index ee8e754..0000000 --- a/examples/smol_tls_v0.2.2/src/main.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::{ - io::{BufReader, Cursor}, - sync::Arc, -}; - -use async_trait::async_trait; -use mqrstt::{ - new_smol, - packets::{self, Packet}, - smol::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; -use rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; - -pub const EMQX_CERT: &[u8] = include_bytes!("broker.emqx.io-ca.crt"); - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[derive(Debug, Clone)] -pub enum PrivateKey { - RSA(Vec), - ECC(Vec), -} - -pub fn simple_rust_tls(ca: Vec, alpn: Option>>, client_auth: Option<(Vec, PrivateKey)>) -> Result, rustls::Error> { - let mut root_cert_store = RootCertStore::empty(); - - let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); - - let trust_anchors = ca_certs.iter().map_while(|cert| { - if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)) - } else { - None - } - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - assert!(!root_cert_store.is_empty()); - - let config = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store); - - let mut config = match client_auth { - Some((client_cert_info, client_private_info)) => { - let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))), - PrivateKey::ECC(ecc) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))), - } - .unwrap(); - - let key = read_private_keys.into_iter().next().unwrap(); - - let client_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))).unwrap(); - let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); - - config.with_single_cert(client_cert_chain, rustls::PrivateKey(key))? - } - None => config.with_no_client_auth(), - }; - - if let Some(alpn) = alpn { - config.alpn_protocols.extend(alpn) - } - - Ok(Arc::new(config)) -} - -fn main() { - smol::block_on(async { - let client_id = "SmolTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_smol(options); - - let arc_client_config = simple_rust_tls(EMQX_CERT.to_vec(), None, None).unwrap(); - - let domain = ServerName::try_from(address).unwrap(); - let connector = async_rustls::TlsConnector::from(arc_client_config); - - let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); - let connection = connector.connect(domain, stream).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(connection, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); - }); -} diff --git a/examples/sync_tcp_v0.2.2/Cargo.toml b/examples/sync_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index c06ba83..0000000 --- a/examples/sync_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,10 +0,0 @@ -[package] -name = "sync_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["sync"]} diff --git a/examples/sync_tcp_v0.2.2/src/main.rs b/examples/sync_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 4abb3f1..0000000 --- a/examples/sync_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,68 +0,0 @@ -use std::time::Duration; - -use mqrstt::{ - new_sync, packets::{self, Packet}, sync::NetworkStatus, ConnectOptions, EventHandler, MqttClient -}; - -pub struct PingPong { - pub client: MqttClient, -} - -impl EventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish_blocking(p.topic.clone(), p.qos, p.retain, "pong").unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -fn main() { - let client_id = "SyncTcp_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 1883; - - let (mut network, client) = new_sync(options); - - let stream = std::net::TcpStream::connect((address, port)).unwrap(); - stream.set_nonblocking(true).unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).unwrap(); - - client.subscribe_blocking("mqrstt").unwrap(); - - let thread = std::thread::spawn(move || { - loop { - match network.poll(&mut pingpong) { - // The client is active but there is no data to be read - Ok(NetworkStatus::ActivePending) => std::thread::sleep(Duration::from_millis(100)), - // The client is active and there is data to be read - Ok(NetworkStatus::ActiveReady) => continue, - // The rest is an error - otherwise => return otherwise, - }; - } - }); - - std::thread::sleep(std::time::Duration::from_secs(30)); - client.disconnect_blocking().unwrap(); - - // Unwrap possible join errors on the thread. - let n = thread.join().unwrap(); - assert!(n.is_ok()); -} diff --git a/examples/tokio_tcp_v0.2.2/Cargo.toml b/examples/tokio_tcp_v0.2.2/Cargo.toml deleted file mode 100644 index 6a5954e..0000000 --- a/examples/tokio_tcp_v0.2.2/Cargo.toml +++ /dev/null @@ -1,18 +0,0 @@ -[package] -name = "tokio_tcp_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["tokio"]} - -tokio = { version = "1.26.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } -tokio-rustls = "0.24.0" - -async-trait = "0.1.68" - -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } diff --git a/examples/tokio_tcp_v0.2.2/src/main.rs b/examples/tokio_tcp_v0.2.2/src/main.rs deleted file mode 100644 index 4e98b94..0000000 --- a/examples/tokio_tcp_v0.2.2/src/main.rs +++ /dev/null @@ -1,69 +0,0 @@ -use std::time::Duration; - -use async_trait::async_trait; -use mqrstt::{ - new_tokio, - packets::{self, Packet}, - tokio::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[tokio::main] -async fn main() { - let client_id = "TokioTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_tokio(options); - - let stream = tokio::net::TcpStream::connect((address, port)).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(stream, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -} diff --git a/examples/tokio_tls_v0.2.2/Cargo.toml b/examples/tokio_tls_v0.2.2/Cargo.toml deleted file mode 100644 index 2890bf9..0000000 --- a/examples/tokio_tls_v0.2.2/Cargo.toml +++ /dev/null @@ -1,19 +0,0 @@ -[package] -name = "tokio_tls_v0_2_2" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -mqrstt = { version = "0.2.2", default-features = false, features = ["tokio"]} - -tokio = { version = "1.26.0", features = ["rt-multi-thread", "rt", "macros", "sync", "io-util", "net", "time"] } -tokio-rustls = "0.24.0" - -async-trait = "0.1.68" - -rustls = { version = "0.20.7" } -rustls-pemfile = { version = "1.0.1" } -webpki = { version = "0.22.0" } diff --git a/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt b/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt deleted file mode 100644 index fd4341d..0000000 --- a/examples/tokio_tls_v0.2.2/src/broker.emqx.io-ca.crt +++ /dev/null @@ -1,22 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDrzCCApegAwIBAgIQCDvgVpBCRrGhdWrJWZHHSjANBgkqhkiG9w0BAQUFADBh -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMSAwHgYDVQQDExdEaWdpQ2VydCBHbG9iYWwgUm9vdCBD -QTAeFw0wNjExMTAwMDAwMDBaFw0zMTExMTAwMDAwMDBaMGExCzAJBgNVBAYTAlVT -MRUwEwYDVQQKEwxEaWdpQ2VydCBJbmMxGTAXBgNVBAsTEHd3dy5kaWdpY2VydC5j -b20xIDAeBgNVBAMTF0RpZ2lDZXJ0IEdsb2JhbCBSb290IENBMIIBIjANBgkqhkiG -9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4jvhEXLeqKTTo1eqUKKPC3eQyaKl7hLOllsB -CSDMAZOnTjC3U/dDxGkAV53ijSLdhwZAAIEJzs4bg7/fzTtxRuLWZscFs3YnFo97 -nh6Vfe63SKMI2tavegw5BmV/Sl0fvBf4q77uKNd0f3p4mVmFaG5cIzJLv07A6Fpt -43C/dxC//AH2hdmoRBBYMql1GNXRor5H4idq9Joz+EkIYIvUX7Q6hL+hqkpMfT7P -T19sdl6gSzeRntwi5m3OFBqOasv+zbMUZBfHWymeMr/y7vrTC0LUq7dBMtoM1O/4 -gdW7jVg/tRvoSSiicNoxBN33shbyTApOB6jtSj1etX+jkMOvJwIDAQABo2MwYTAO -BgNVHQ8BAf8EBAMCAYYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUA95QNVbR -TLtm8KPiGxvDl7I90VUwHwYDVR0jBBgwFoAUA95QNVbRTLtm8KPiGxvDl7I90VUw -DQYJKoZIhvcNAQEFBQADggEBAMucN6pIExIK+t1EnE9SsPTfrgT1eXkIoyQY/Esr -hMAtudXH/vTBH1jLuG2cenTnmCmrEbXjcKChzUyImZOMkXDiqw8cvpOp/2PV5Adg -06O/nVsJ8dWO41P0jmP6P6fbtGbfYmbW0W5BjfIttep3Sp+dWOIrWcBAI+0tKIJF -PnlUkiaY4IBIqDfv8NZ5YBberOgOzW6sRBc4L0na4UU+Krk2U886UAb3LujEV0ls -YSEY1QSteDwsOoBrp+uvFRTp2InBuThs4pFsiv9kuXclVzDAGySj4dzp30d8tbQk -CAUw7C29C79Fv1C5qfPrmAESrciIxpg0X40KPMbp1ZWVbd4= ------END CERTIFICATE----- diff --git a/examples/tokio_tls_v0.2.2/src/main.rs b/examples/tokio_tls_v0.2.2/src/main.rs deleted file mode 100644 index 0cc1da9..0000000 --- a/examples/tokio_tls_v0.2.2/src/main.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::{ - io::{BufReader, Cursor}, - sync::Arc, - time::Duration, -}; - -use async_trait::async_trait; -use mqrstt::{ - new_tokio, - packets::{self, Packet}, - tokio::NetworkStatus, - AsyncEventHandler, ConnectOptions, MqttClient, -}; -use tokio_rustls::rustls::{Certificate, ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName}; - -pub const EMQX_CERT: &[u8] = include_bytes!("broker.emqx.io-ca.crt"); - -pub struct PingPong { - pub client: MqttClient, -} - -#[async_trait] -impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, "pong").await.unwrap(); - println!("Received Ping, Send pong!"); - } - } - } - Packet::ConnAck(_) => { - println!("Connected!") - } - _ => (), - } - } -} - -#[derive(Debug, Clone)] -pub enum PrivateKey { - RSA(Vec), - ECC(Vec), -} - -pub fn simple_rust_tls(ca: Vec, alpn: Option>>, client_auth: Option<(Vec, PrivateKey)>) -> Result, rustls::Error> { - let mut root_cert_store = RootCertStore::empty(); - - let ca_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(ca))).unwrap(); - - let trust_anchors = ca_certs.iter().map_while(|cert| { - if let Ok(ta) = webpki::TrustAnchor::try_from_cert_der(&cert[..]) { - Some(OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)) - } else { - None - } - }); - root_cert_store.add_server_trust_anchors(trust_anchors); - - assert!(!root_cert_store.is_empty()); - - let config = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store); - - let mut config = match client_auth { - Some((client_cert_info, client_private_info)) => { - let read_private_keys = match client_private_info { - PrivateKey::RSA(rsa) => rustls_pemfile::rsa_private_keys(&mut BufReader::new(Cursor::new(rsa))), - PrivateKey::ECC(ecc) => rustls_pemfile::pkcs8_private_keys(&mut BufReader::new(Cursor::new(ecc))), - } - .unwrap(); - - let key = read_private_keys.into_iter().next().unwrap(); - - let client_certs = rustls_pemfile::certs(&mut BufReader::new(Cursor::new(client_cert_info))).unwrap(); - let client_cert_chain = client_certs.into_iter().map(Certificate).collect(); - - config.with_single_cert(client_cert_chain, tokio_rustls::rustls::PrivateKey(key)).unwrap() - } - None => config.with_no_client_auth(), - }; - - if let Some(alpn) = alpn { - config.alpn_protocols.extend(alpn) - } - - Ok(Arc::new(config)) -} - -#[tokio::main] -async fn main() { - let client_id = "TokioTls_MQrsTT_Example".to_string(); - let options = ConnectOptions::new(client_id); - - let address = "broker.emqx.io"; - let port = 8883; - - let (mut network, client) = new_tokio(options); - - let arc_client_config = simple_rust_tls(EMQX_CERT.to_vec(), None, None).unwrap(); - - let domain = ServerName::try_from(address).unwrap(); - let connector = tokio_rustls::TlsConnector::from(arc_client_config); - - let stream = tokio::net::TcpStream::connect((address, port)).await.unwrap(); - let connection = connector.connect(domain, stream).await.unwrap(); - - let mut pingpong = PingPong { client: client.clone() }; - - network.connect(connection, &mut pingpong).await.unwrap(); - - client.subscribe("mqrstt").await.unwrap(); - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -} From 6554d4b8e24a0f4721dc8ed508e831daa88cbc20 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:01:03 +0100 Subject: [PATCH 06/45] covert properties and reason codes to macro --- .gitignore | 4 +- Cargo.toml | 1 - mqrstt/Cargo.toml | 17 +- mqrstt/benches/benchmarks/mod.rs | 21 +- mqrstt/src/available_packet_ids.rs | 4 +- mqrstt/src/client.rs | 37 +- mqrstt/src/connect_options.rs | 17 +- mqrstt/src/error.rs | 2 +- mqrstt/src/packets/auth.rs | 149 --- mqrstt/src/packets/auth/mod.rs | 52 + mqrstt/src/packets/auth/properties.rs | 129 +++ mqrstt/src/packets/auth/reason_code.rs | 5 + mqrstt/src/packets/connack/mod.rs | 238 +++++ .../{connack.rs => connack/properties.rs} | 942 ++++++------------ mqrstt/src/packets/connack/reason_code.rs | 24 + mqrstt/src/packets/connect.rs | 892 ----------------- mqrstt/src/packets/connect/connect_flags.rs | 90 ++ .../src/packets/connect/connect_properties.rs | 236 +++++ mqrstt/src/packets/connect/last_will.rs | 83 ++ .../packets/connect/last_will_properties.rs | 167 ++++ mqrstt/src/packets/connect/mod.rs | 496 +++++++++ mqrstt/src/packets/disconnect.rs | 226 ----- mqrstt/src/packets/disconnect/mod.rs | 178 ++++ mqrstt/src/packets/disconnect/properties.rs | 110 ++ mqrstt/src/packets/disconnect/reason_code.rs | 148 +++ mqrstt/src/packets/macros/mod.rs | 8 + .../src/packets/macros/properties_macros.rs | 140 ++- .../src/packets/macros/reason_code_macros.rs | 113 ++- mqrstt/src/packets/mod.rs | 688 +------------ mqrstt/src/packets/mqtt_trait/mod.rs | 106 ++ .../src/packets/mqtt_trait/primitive_impl.rs | 269 +++++ mqrstt/src/packets/mqtt_traits.rs | 50 - mqrstt/src/packets/primitive/fixed_header.rs | 40 + mqrstt/src/packets/primitive/mod.rs | 13 + mqrstt/src/packets/primitive/property_type.rs | 140 +++ .../src/packets/primitive/protocol_version.rs | 43 + mqrstt/src/packets/primitive/qos.rs | 71 ++ .../src/packets/primitive/variable_integer.rs | 182 ++++ .../src/packets/{puback.rs => puback/mod.rs} | 98 +- mqrstt/src/packets/puback/reason_code.rs | 67 ++ .../packets/{pubcomp.rs => pubcomp/mod.rs} | 131 +-- mqrstt/src/packets/pubcomp/properties.rs | 89 ++ mqrstt/src/packets/pubcomp/reason_code.rs | 5 + mqrstt/src/packets/publish/mod.rs | 202 ++++ .../{publish.rs => publish/properties.rs} | 587 ++++------- .../src/packets/{pubrec.rs => pubrec/mod.rs} | 99 +- mqrstt/src/packets/pubrec/properties.rs | 84 ++ mqrstt/src/packets/pubrec/reason_code.rs | 67 ++ .../src/packets/{pubrel.rs => pubrel/mod.rs} | 147 +-- mqrstt/src/packets/pubrel/properties.rs | 122 +++ mqrstt/src/packets/pubrel/reason_code.rs | 49 + mqrstt/src/packets/reason_codes.rs | 510 ++-------- mqrstt/src/packets/suback.rs | 160 --- mqrstt/src/packets/suback/mod.rs | 116 +++ mqrstt/src/packets/suback/properties.rs | 92 ++ mqrstt/src/packets/suback/reason_code.rs | 61 ++ mqrstt/src/packets/subscribe.rs | 18 +- mqrstt/src/packets/subscribe/properties.rs | 0 mqrstt/src/packets/subscribe/reason_code.rs | 0 mqrstt/src/packets/unsuback.rs | 149 --- mqrstt/src/packets/unsuback/mod.rs | 80 ++ mqrstt/src/packets/unsuback/properties.rs | 82 ++ mqrstt/src/packets/unsuback/reason_code.rs | 58 ++ .../{unsubscribe.rs => unsubscribe/mod.rs} | 104 +- mqrstt/src/packets/unsubscribe/properties.rs | 69 ++ mqrstt/src/smol/network.rs | 3 +- mqrstt/src/smol/stream.rs | 2 +- mqrstt/src/state_handler.rs | 13 +- mqrstt/src/tests/test_bytes.rs | 2 +- mqrstt/src/tests/test_packets.rs | 13 +- mqrstt/src/tokio/network.rs | 2 +- mqrstt/src/tokio/stream/mod.rs | 2 +- 72 files changed, 5156 insertions(+), 4258 deletions(-) delete mode 100644 mqrstt/src/packets/auth.rs create mode 100644 mqrstt/src/packets/auth/mod.rs create mode 100644 mqrstt/src/packets/auth/properties.rs create mode 100644 mqrstt/src/packets/auth/reason_code.rs create mode 100644 mqrstt/src/packets/connack/mod.rs rename mqrstt/src/packets/{connack.rs => connack/properties.rs} (56%) create mode 100644 mqrstt/src/packets/connack/reason_code.rs delete mode 100644 mqrstt/src/packets/connect.rs create mode 100644 mqrstt/src/packets/connect/connect_flags.rs create mode 100644 mqrstt/src/packets/connect/connect_properties.rs create mode 100644 mqrstt/src/packets/connect/last_will.rs create mode 100644 mqrstt/src/packets/connect/last_will_properties.rs create mode 100644 mqrstt/src/packets/connect/mod.rs delete mode 100644 mqrstt/src/packets/disconnect.rs create mode 100644 mqrstt/src/packets/disconnect/mod.rs create mode 100644 mqrstt/src/packets/disconnect/properties.rs create mode 100644 mqrstt/src/packets/disconnect/reason_code.rs create mode 100644 mqrstt/src/packets/mqtt_trait/mod.rs create mode 100644 mqrstt/src/packets/mqtt_trait/primitive_impl.rs delete mode 100644 mqrstt/src/packets/mqtt_traits.rs create mode 100644 mqrstt/src/packets/primitive/fixed_header.rs create mode 100644 mqrstt/src/packets/primitive/mod.rs create mode 100644 mqrstt/src/packets/primitive/property_type.rs create mode 100644 mqrstt/src/packets/primitive/protocol_version.rs create mode 100644 mqrstt/src/packets/primitive/qos.rs create mode 100644 mqrstt/src/packets/primitive/variable_integer.rs rename mqrstt/src/packets/{puback.rs => puback/mod.rs} (75%) create mode 100644 mqrstt/src/packets/puback/reason_code.rs rename mqrstt/src/packets/{pubcomp.rs => pubcomp/mod.rs} (68%) create mode 100644 mqrstt/src/packets/pubcomp/properties.rs create mode 100644 mqrstt/src/packets/pubcomp/reason_code.rs create mode 100644 mqrstt/src/packets/publish/mod.rs rename mqrstt/src/packets/{publish.rs => publish/properties.rs} (54%) rename mqrstt/src/packets/{pubrec.rs => pubrec/mod.rs} (68%) create mode 100644 mqrstt/src/packets/pubrec/properties.rs create mode 100644 mqrstt/src/packets/pubrec/reason_code.rs rename mqrstt/src/packets/{pubrel.rs => pubrel/mod.rs} (70%) create mode 100644 mqrstt/src/packets/pubrel/properties.rs create mode 100644 mqrstt/src/packets/pubrel/reason_code.rs delete mode 100644 mqrstt/src/packets/suback.rs create mode 100644 mqrstt/src/packets/suback/mod.rs create mode 100644 mqrstt/src/packets/suback/properties.rs create mode 100644 mqrstt/src/packets/suback/reason_code.rs create mode 100644 mqrstt/src/packets/subscribe/properties.rs create mode 100644 mqrstt/src/packets/subscribe/reason_code.rs delete mode 100644 mqrstt/src/packets/unsuback.rs create mode 100644 mqrstt/src/packets/unsuback/mod.rs create mode 100644 mqrstt/src/packets/unsuback/properties.rs create mode 100644 mqrstt/src/packets/unsuback/reason_code.rs rename mqrstt/src/packets/{unsubscribe.rs => unsubscribe/mod.rs} (76%) create mode 100644 mqrstt/src/packets/unsubscribe/properties.rs diff --git a/.gitignore b/.gitignore index 9b5ed2a..0b6e4ba 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,6 @@ **/target examples/tokio_tls/Cargo.lock examples/smol_tls/Cargo.lock -.vscode/** \ No newline at end of file +.vscode/** + +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 8c39591..56f4f9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,5 +2,4 @@ members = [ "mqrstt", - "examples/*", ] \ No newline at end of file diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index ba4a6d0..b433692 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -28,24 +28,21 @@ test = [] [dependencies] # Packets -bytes = "1.5.0" +bytes = "1" # Errors -thiserror = "1.0.53" -tracing = { version = "0.1.40", optional = true } +thiserror = "1" +tracing = { version = "0.1", optional = true } -async-channel = "2.1.1" -#async-mutex = "1.4.0" -futures = { version = "0.3.30", default-features = false, features = [ +async-channel = "2" +futures = { version = "0.3", default-features = false, features = [ "std", "async-await", ] } -# quic feature flag -# quinn = {version = "0.9.0", optional = true } # tokio feature flag -tokio = { version = "1.35.1", features = [ +tokio = { version = "1", features = [ "macros", "io-util", "net", @@ -53,7 +50,7 @@ tokio = { version = "1.35.1", features = [ ], optional = true } # smol feature flag -smol = { version = "2.0.0", optional = true } +smol = { version = "2", optional = true } [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } diff --git a/mqrstt/benches/benchmarks/mod.rs b/mqrstt/benches/benchmarks/mod.rs index 6a66044..3b5244f 100644 --- a/mqrstt/benches/benchmarks/mod.rs +++ b/mqrstt/benches/benchmarks/mod.rs @@ -4,20 +4,21 @@ use mqrstt::packets::{Disconnect, Packet, Publish}; pub mod tokio; fn fill_stuff(buffer: &mut BytesMut, publ_count: usize, publ_size: usize) { - empty_connect(buffer); - for i in 0..publ_count { - very_large_publish(i as u16, publ_size / 5).write(buffer).unwrap(); - } - empty_disconnect().write(buffer).unwrap(); + // empty_connect(buffer); + // for i in 0..publ_count { + // very_large_publish(i as u16, publ_size / 5).write(buffer).unwrap(); + // } + // empty_disconnect().write(buffer).unwrap(); } fn empty_disconnect() -> Packet { - let discon = Disconnect { - reason_code: mqrstt::packets::reason_codes::DisconnectReasonCode::ServerBusy, - properties: Default::default(), - }; + // let discon = Disconnect { + // reason_code: mqrstt::packets::reason_codes::DisconnectReasonCode::ServerBusy, + // properties: Default::default(), + // }; - Packet::Disconnect(discon) + // Packet::Disconnect(discon) + todo!() } fn empty_connect(buffer: &mut BytesMut) { diff --git a/mqrstt/src/available_packet_ids.rs b/mqrstt/src/available_packet_ids.rs index 1527d8e..6a0e28b 100644 --- a/mqrstt/src/available_packet_ids.rs +++ b/mqrstt/src/available_packet_ids.rs @@ -6,7 +6,7 @@ use tracing::{debug, error}; use crate::error::HandlerError; #[derive(Debug, Clone)] -pub struct AvailablePacketIds { +pub(crate) struct AvailablePacketIds { sender: Sender, } @@ -22,7 +22,7 @@ impl AvailablePacketIds { (apkid, r) } - pub fn mark_available(&self, pkid: u16) -> Result<(), HandlerError> { + pub(crate) fn mark_available(&self, pkid: u16) -> Result<(), HandlerError> { match self.sender.try_send(pkid) { Ok(_) => { #[cfg(feature = "logs")] diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index d886114..f447375 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -7,9 +7,18 @@ use tracing::info; use crate::{ error::ClientError, packets::{ - mqtt_traits::PacketValidation, - reason_codes::DisconnectReasonCode, - Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, Subscription}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, + mqtt_trait::PacketValidation, + DisconnectReasonCode, + Packet, QoS, + // disconnect::{Disconnect, DisconnectProperties}, + // publish::{Publish, PublishProperties}, + // subscribe::{Subscribe, SubscribeProperties, Subscription}, + // unsubscribe::{Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, + + {Disconnect, DisconnectProperties}, + {Publish, PublishProperties}, + {Subscribe, SubscribeProperties, Subscription}, + {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, }, }; @@ -379,7 +388,7 @@ impl MqttClient { /// use mqrstt::packets::UnsubscribeProperties; /// /// let properties = UnsubscribeProperties{ - /// user_properties: vec![("property".to_string(), "value".to_string())], + /// user_properties: vec![("property".into(), "value".into())], /// }; /// /// // Unsubscribe from a single topic specified as a string: @@ -387,7 +396,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topic, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as an array of string slices: @@ -395,7 +404,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topics.as_slice(), properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from a single topic specified as a String: @@ -403,7 +412,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topic, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as a Vec: @@ -411,7 +420,7 @@ impl MqttClient { /// mqtt_client.unsubscribe_with_properties(topics, properties).await; /// /// # let properties = UnsubscribeProperties{ - /// # user_properties: vec![("property".to_string(), "value".to_string())], + /// # user_properties: vec![("property".into(), "value".into())], /// # }; /// /// // Unsubscribe from multiple topics specified as an array of String: @@ -472,7 +481,7 @@ impl MqttClient { /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; - /// use mqrstt::packets::reason_codes::DisconnectReasonCode; + /// use mqrstt::packets::DisconnectReasonCode; /// /// let properties = DisconnectProperties { /// reason_string: Some("Reason here".into()), @@ -788,7 +797,7 @@ impl MqttClient { /// use mqrstt::packets::UnsubscribeProperties; /// /// let properties = UnsubscribeProperties{ - /// user_properties: vec![("property".to_string(), "value".to_string())], + /// user_properties: vec![("property".into(), "value".into())], /// }; /// # let properties_clone = properties.clone(); /// @@ -873,7 +882,7 @@ impl MqttClient { /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; - /// use mqrstt::packets::reason_codes::DisconnectReasonCode; + /// use mqrstt::packets::DisconnectReasonCode; /// /// let properties = DisconnectProperties { /// reason_string: Some("Reason here".into()), @@ -898,7 +907,7 @@ mod tests { use crate::{ error::{ClientError, PacketValidationError}, - packets::{reason_codes::DisconnectReasonCode, DisconnectProperties, Packet, PacketType, Publish, QoS, Subscribe, SubscribeProperties, UnsubscribeProperties}, + packets::{DisconnectProperties, DisconnectReasonCode, Packet, PacketType, Publish, QoS, Subscribe, SubscribeProperties, UnsubscribeProperties}, }; use super::MqttClient; @@ -1047,7 +1056,7 @@ mod tests { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); let properties = UnsubscribeProperties{ - user_properties: vec![("property".to_string(), "value".to_string())], + user_properties: vec![("property".into(), "value".into())], }; // Unsubscribe from a single topic specified as a string: @@ -1193,7 +1202,7 @@ mod tests { let (client, client_to_handler_r, _) = create_new_test_client(); let prop = UnsubscribeProperties { - user_properties: vec![("A".to_string(), "B".to_string())], + user_properties: vec![("A".into(), "B".into())], }; client.unsubscribe_with_properties("Topic", prop.clone()).await.unwrap(); diff --git a/mqrstt/src/connect_options.rs b/mqrstt/src/connect_options.rs index 15af570..f83164b 100644 --- a/mqrstt/src/connect_options.rs +++ b/mqrstt/src/connect_options.rs @@ -41,7 +41,7 @@ pub struct ConnectOptions { request_problem_information: Option, user_properties: Vec<(Box, Box)>, authentication_method: Option>, - authentication_data: Bytes, + authentication_data: Option>, /// Last will that will be issued on unexpected disconnect last_will: Option, @@ -62,9 +62,9 @@ impl Default for ConnectOptions { topic_alias_maximum: None, request_response_information: None, request_problem_information: None, - user_properties: vec![], + user_properties: Vec::new(), authentication_method: None, - authentication_data: Bytes::new(), + authentication_data: None, last_will: None, } } @@ -72,8 +72,11 @@ impl Default for ConnectOptions { impl ConnectOptions { /// Create a new [`ConnectOptions`] - /// ClientId recommendation: - /// - 1 to 23 bytes UTF-8 bytes + /// + /// Be aware: + /// This client does not restrict the client identifier in any way. However, the MQTT v5.0 specification does. + /// It is thus recommended to use a client id that is compatible with the MQTT v5.0 specification. + /// - 1 to 23 bytes UTF-8 bytes. /// - Contains [a-zA-Z0-9] characters only. /// /// Some brokers accept longer client ids with different characters @@ -94,7 +97,7 @@ impl ConnectOptions { request_problem_information: None, user_properties: vec![], authentication_method: None, - authentication_data: Bytes::new(), + authentication_data: None, last_will: None, } } @@ -119,7 +122,7 @@ impl ConnectOptions { username: self.username.clone(), password: self.password.clone(), connect_properties, - protocol_version: crate::packets::ProtocolVersion::V5, + protocol_version: crate::packets::protocol_version::ProtocolVersion::V5, last_will: self.last_will.clone(), }; diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index 6475b33..7b6cac7 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -4,7 +4,7 @@ use async_channel::{RecvError, SendError}; use crate::packets::{ error::{DeserializeError, ReadBytes, SerializeError}, - reason_codes::ConnAckReasonCode, + ConnAckReasonCode, {Packet, PacketType}, }; diff --git a/mqrstt/src/packets/auth.rs b/mqrstt/src/packets/auth.rs deleted file mode 100644 index 0da56c0..0000000 --- a/mqrstt/src/packets/auth.rs +++ /dev/null @@ -1,149 +0,0 @@ -use bytes::Bytes; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::AuthReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Auth { - pub reason_code: AuthReasonCode, - pub properties: AuthProperties, -} - -impl PacketRead for Auth { - fn read(_: u8, _: usize, mut buf: Bytes) -> Result { - let reason_code = AuthReasonCode::read(&mut buf)?; - let properties = AuthProperties::read(&mut buf)?; - - Ok(Self { reason_code, properties }) - } -} - -impl PacketWrite for Auth { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - Ok(()) - } -} - -impl WireLength for Auth { - fn wire_len(&self) -> usize { - 1 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len() - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct AuthProperties { - /// 3.15.2.2.2 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method. - pub authentication_method: Option>, - - /// 3.15.2.2.3 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - pub authentication_data: Bytes, - - /// 3.15.2.2.4 Reason String - /// 31 (0x1F) Byte, Identifier of the Reason String - pub reason_string: Option>, - - /// 3.15.2.2.5 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for AuthProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = AuthProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::MalformedPacket); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Bytes::read(&mut property_data)?; - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for AuthProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(authentication_method) = &self.authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - PropertyType::AuthenticationData.write(buf)?; - self.authentication_data.write(buf)?; - } - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - - Ok(()) - } -} - -impl WireLength for AuthProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(authentication_method) = &self.authentication_method { - len += authentication_method.wire_len(); - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - len += self.authentication_data.wire_len(); - } - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += key.wire_len() + value.wire_len(); - } - len - } -} diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs new file mode 100644 index 0000000..ad4905c --- /dev/null +++ b/mqrstt/src/packets/auth/mod.rs @@ -0,0 +1,52 @@ +mod properties; +pub use properties::AuthProperties; +mod reason_code; +pub use reason_code::AuthReasonCode; + +use bytes::Bytes; + +use super::{mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger}; + + +#[derive(Debug, Clone, PartialEq, Eq)] +/// The AUTH packet is used to perform more intriquite authentication methods. +/// +/// At the time of writing this client does not (yet) provide the user a method of handling the auth handshake. +/// There are several other ways to perform authentication, for example using TLS. +/// Additionally, not many clients support this packet fully. +pub struct Auth { + pub reason_code: AuthReasonCode, + pub properties: AuthProperties, +} + +impl PacketRead for Auth { + fn read(_: u8, _: usize, mut buf: Bytes) -> Result { + let reason_code = AuthReasonCode::read(&mut buf)?; + let properties = AuthProperties::read(&mut buf)?; + + Ok(Self { reason_code, properties }) + } +} + +impl PacketAsyncRead for Auth where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let (reason_code, reason_code_read_bytes) = AuthReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = AuthProperties::async_read(stream).await?; + + Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes )) + } +} + +impl PacketWrite for Auth { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + Ok(()) + } +} + +impl WireLength for Auth { + fn wire_len(&self) -> usize { + 1 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/auth/properties.rs b/mqrstt/src/packets/auth/properties.rs new file mode 100644 index 0000000..259b11d --- /dev/null +++ b/mqrstt/src/packets/auth/properties.rs @@ -0,0 +1,129 @@ +use bytes::Bytes; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite,WireLength}, PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + AuthProperties, + AuthenticationMethod, + AuthenticationData, + ReasonString, + UserProperty +); + +// #[derive(Debug, Default, PartialEq, Eq, Clone)] +// pub struct AuthProperties { +// /// 3.15.2.2.2 Authentication Method +// /// 21 (0x15) Byte, Identifier of the Authentication Method. +// pub authentication_method: Option>, + +// /// 3.15.2.2.3 Authentication Data +// /// 22 (0x16) Byte, Identifier of the Authentication Data +// pub authentication_data: Vec, + +// /// 3.15.2.2.4 Reason String +// /// 31 (0x1F) Byte, Identifier of the Reason String +// pub reason_string: Option>, + +// /// 3.15.2.2.5 User Property +// /// 38 (0x26) Byte, Identifier of the User Property. +// pub user_properties: Vec<(Box, Box)>, +// } + +impl MqttRead for AuthProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = AuthProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::MalformedPacket); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for AuthProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(authentication_method) = &self.authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = &self.authentication_data { + if !authentication_data.is_empty() && self.authentication_method.is_some() { + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + } + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + + Ok(()) + } +} + +// impl WireLength for AuthProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(authentication_method) = &self.authentication_method { +// len += 1 + authentication_method.wire_len(); +// } +// if let Some(authentication_data) = self.authentication_data { +// if !authentication_data.is_empty() && self.authentication_method.is_some() { +// len += 1 + authentication_data.wire_len(); +// } +// } +// if let Some(reason_string) = &self.reason_string { +// len += 1 + reason_string.wire_len(); +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/auth/reason_code.rs b/mqrstt/src/packets/auth/reason_code.rs new file mode 100644 index 0000000..364e1aa --- /dev/null +++ b/mqrstt/src/packets/auth/reason_code.rs @@ -0,0 +1,5 @@ +crate::packets::macros::reason_code!(AuthReasonCode, + Success, + ContinueAuthentication, + ReAuthenticate +); \ No newline at end of file diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs new file mode 100644 index 0000000..1e8e1ce --- /dev/null +++ b/mqrstt/src/packets/connack/mod.rs @@ -0,0 +1,238 @@ +mod properties; +pub use properties::ConnAckProperties; + +mod reason_code; +pub use reason_code::ConnAckReasonCode; + + +use super::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, QoS, VariableInteger, +}; +use bytes::{Buf, BufMut}; + + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct ConnAck { + /// 3.2.2.1 Connect Acknowledge Flags + pub connack_flags: ConnAckFlags, + + /// 3.2.2.2 Connect Reason Code + /// Byte 2 in the Variable Header is the Connect Reason Code. + pub reason_code: ConnAckReasonCode, + + /// 3.2.2.3 CONNACK Properties + pub connack_properties: ConnAckProperties, +} + +impl PacketRead for ConnAck { + fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { + if header_len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), header_len)); + } + + let connack_flags = ConnAckFlags::read(&mut buf)?; + let reason_code = ConnAckReasonCode::read(&mut buf)?; + let connack_properties = ConnAckProperties::read(&mut buf)?; + + Ok(Self { + connack_flags, + reason_code, + connack_properties, + }) + } +} + +impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; + let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; + let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; + + Ok(( + Self { + connack_flags, + reason_code, + connack_properties, + }, + read_bytes + reason_code_read_bytes + connack_properties_read_bytes + )) + + } + } +} + +impl PacketWrite for ConnAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.connack_flags.write(buf)?; + self.reason_code.write(buf)?; + self.connack_properties.write(buf)?; + + Ok(()) + } +} + +impl WireLength for ConnAck { + fn wire_len(&self) -> usize { + 2 + // 1 for connack_flags and 1 for reason_code + self.connack_properties.wire_len().variable_integer_len() + + self.connack_properties.wire_len() + } +} + + +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct ConnAckFlags { + pub session_present: bool, +} + +impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(stream: &mut S) -> impl std::future::Future> { + async move { + let byte = stream.read_u8().await?; + Ok((Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, 1)) + } + } +} + +impl MqttRead for ConnAckFlags { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + let byte = buf.get_u8(); + + Ok(Self { + session_present: (byte & 0b00000001) == 0b00000001, + }) + } +} + +impl MqttWrite for ConnAckFlags { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + let byte = self.session_present as u8; + + buf.put_u8(byte); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + + use crate::packets::{ + connack::{ConnAck, ConnAckProperties}, mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, ConnAckReasonCode, Packet, VariableInteger + }; + + #[test] + fn test_wire_len() { + let mut buf = bytes::BytesMut::new(); + + let connack_properties = ConnAckProperties { + session_expiry_interval: Some(60), // Session expiry interval in seconds + receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently + maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server + retain_available: Some(true), // Whether the server supports retained messages + maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept + assigned_client_id: Some(Box::from("client-12345")), // Client identifier assigned by the server + topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server + reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment + user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair + wildcards_available: Some(true), // Whether wildcard subscriptions are available + subscription_ids_available: Some(true), // Whether subscription identifiers are available + shared_subscription_available: Some(true), // Whether shared subscriptions are available + server_keep_alive: Some(120), // Server keep alive time in seconds + response_info: Some(Box::from("Response info")), // Response information + server_reference: Some(Box::from("server-reference")), // Server reference + authentication_method: Some(Box::from("auth-method")), // Authentication method + authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data + }; + + let len = connack_properties.wire_len(); + // determine length of variable integer + let len_of_wire_len = len.write_variable_integer(&mut buf).unwrap(); + // clear buffer before writing actual properties + buf.clear(); + connack_properties.write(&mut buf).unwrap(); + + assert_eq!(len + len_of_wire_len, buf.len()); + + } + + #[test] + fn read_write_connack_packet() { + let c = ConnAck { ..Default::default() }; + + let p1 = Packet::ConnAck(c); + let mut buf = bytes::BytesMut::new(); + + p1.write(&mut buf).unwrap(); + + let p2 = Packet::read_from_buffer(&mut buf).unwrap(); + + assert_eq!(p1, p2); + } + + #[test] + fn read_write_connack() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + 0x01, // Connack flags + 0x00, // Reason code, + 0x00, // empty properties + ]; + + buf.extend_from_slice(packet); + let c1 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(ConnAckReasonCode::Success, c1.reason_code); + assert_eq!(ConnAckProperties::default(), c1.connack_properties); + + let mut buf = bytes::BytesMut::new(); + + c1.write(&mut buf).unwrap(); + + let c2 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(c1, c2) + } + + #[test] + fn read_write_connack_properties() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + 56, // ConnAckProperties variable length + 17, // session_expiry_interval + 0xff, 0xff, 37, // retain_available + 0x1, // true + 18, // Assigned Client Id + 0, 11, // 11 bytes "KeanuReeves" without space + b'K', b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', 36, // Max QoS + 2, // QoS 2 Exactly Once + 34, // Topic Alias Max = 255 + 0, 255, 31, // Reason String = 'Houston we have got a problem' + 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', + ]; + + buf.extend_from_slice(packet); + let c1 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + let mut buf = bytes::BytesMut::new(); + + let variable_length = c1.wire_len(); + assert_eq!(variable_length, 56); + + c1.write(&mut buf).unwrap(); + + let _buf_clone = buf.to_vec(); + + let c2 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + assert_eq!(c1, c2); + } +} diff --git a/mqrstt/src/packets/connack.rs b/mqrstt/src/packets/connack/properties.rs similarity index 56% rename from mqrstt/src/packets/connack.rs rename to mqrstt/src/packets/connack/properties.rs index ef6e92d..9b32e8e 100644 --- a/mqrstt/src/packets/connack.rs +++ b/mqrstt/src/packets/connack/properties.rs @@ -1,616 +1,326 @@ -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::ConnAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, -}; -use bytes::{Buf, BufMut}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct ConnAck { - /// 3.2.2.1 Connect Acknowledge Flags - pub connack_flags: ConnAckFlags, - - /// 3.2.2.2 Connect Reason Code - /// Byte 2 in the Variable Header is the Connect Reason Code. - pub reason_code: ConnAckReasonCode, - - /// 3.2.2.3 CONNACK Properties - pub connack_properties: ConnAckProperties, -} - -impl PacketRead for ConnAck { - fn read(_: u8, header_len: usize, mut buf: bytes::Bytes) -> Result { - if header_len > buf.len() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), header_len)); - } - - let connack_flags = ConnAckFlags::read(&mut buf)?; - let reason_code = ConnAckReasonCode::read(&mut buf)?; - let connack_properties = ConnAckProperties::read(&mut buf)?; - - Ok(Self { - connack_flags, - reason_code, - connack_properties, - }) - } -} - -impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncReadExt + Unpin { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; - let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; - let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; - - Ok(( - Self { - connack_flags, - reason_code, - connack_properties, - }, - read_bytes + reason_code_read_bytes + connack_properties_read_bytes - )) - - } - } -} - -impl PacketWrite for ConnAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - self.connack_flags.write(buf)?; - self.reason_code.write(buf)?; - self.connack_properties.write(buf)?; - - Ok(()) - } -} - -impl WireLength for ConnAck { - fn wire_len(&self) -> usize { - 2 + // 1 for connack_flags and 1 for reason_code - variable_integer_len(self.connack_properties.wire_len()) + - self.connack_properties.wire_len() - } -} - -super::macros::define_properties!(ConnAckProperties, - SessionExpiryInterval, - ReceiveMaximum, - MaximumQos, - RetainAvailable, - MaximumPacketSize, - AssignedClientIdentifier, - TopicAliasMaximum, - ReasonString, - UserProperty, - WildcardSubscriptionAvailable, - SubscriptionIdentifierAvailable, - SharedSubscriptionAvailable, - ServerKeepAlive, - ResponseInformation, - ServerReference, - AuthenticationMethod, - AuthenticationData -); - -// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -// pub struct ConnAckProperties { -// /// 3.2.2.3.2 Session Expiry Interval -// /// 17 (0x11) Byte Identifier of the Session Expiry Interval -// pub session_expiry_interval: Option, - -// /// 3.2.2.3.3 Receive Maximum -// /// 33 (0x21) Byte, Identifier of the Receive Maximum -// pub receive_maximum: Option, - -// /// 3.2.2.3.4 Maximum QoS -// /// 36 (0x24) Byte, Identifier of the Maximum QoS. -// pub maximum_qos: Option, - -// /// 3.2.2.3.5 Retain Available -// /// 37 (0x25) Byte, Identifier of Retain Available. -// pub retain_available: Option, - -// /// 3.2.2.3.6 Maximum Packet Size -// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. -// pub maximum_packet_size: Option, - -// /// 3.2.2.3.7 Assigned Client Identifier -// /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. -// pub assigned_client_id: Option>, - -// /// 3.2.2.3.8 Topic Alias Maximum -// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. -// pub topic_alias_maximum: Option, - -// /// 3.2.2.3.9 Reason String -// /// 31 (0x1F) Byte Identifier of the Reason String. -// pub reason_string: Option>, - -// /// 3.2.2.3.10 User Property -// /// 38 (0x26) Byte, Identifier of User Property. -// pub user_properties: Vec<(Box, Box)>, - -// /// 3.2.2.3.11 Wildcard Subscription Available -// /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. -// pub wildcards_available: Option, - -// /// 3.2.2.3.12 Subscription Identifiers Available -// /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. -// pub subscription_ids_available: Option, - -// /// 3.2.2.3.13 Shared Subscription Available -// /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. -// pub shared_subscription_available: Option, - -// /// 3.2.2.3.14 Server Keep Alive -// /// 19 (0x13) Byte, Identifier of the Server Keep Alive -// pub server_keep_alive: Option, - -// /// 3.2.2.3.15 Response Information -// /// 26 (0x1A) Byte, Identifier of the Response Information. -// pub response_info: Option>, - -// /// 3.2.2.3.16 Server Reference -// /// 28 (0x1C) Byte, Identifier of the Server Reference -// pub server_reference: Option>, - -// /// 3.2.2.3.17 Authentication Method -// /// 21 (0x15) Byte, Identifier of the Authentication Method -// pub authentication_method: Option>, - -// /// 3.2.2.3.18 Authentication Data -// /// 22 (0x16) Byte, Identifier of the Authentication Data -// // There is a small inconsistency here with authentication_data in the connect packet. -// // This is Option while that type uses just Bytes. -// pub authentication_data: Option>, -// } - -impl MqttRead for ConnAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - let property = PropertyType::read(&mut property_data)?; - match property { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ReceiveMaximum => { - if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); - } - properties.receive_maximum = Some(u16::read(&mut property_data)?); - } - PropertyType::MaximumQos => { - if properties.maximum_qos.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumQos)); - } - properties.maximum_qos = Some(QoS::read(&mut property_data)?); - } - PropertyType::RetainAvailable => { - if properties.retain_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable)); - } - properties.retain_available = Some(bool::read(&mut property_data)?); - } - PropertyType::MaximumPacketSize => { - if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); - } - properties.maximum_packet_size = Some(u32::read(&mut property_data)?); - } - PropertyType::AssignedClientIdentifier => { - if properties.assigned_client_id.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); - } - properties.assigned_client_id = Some(Box::::read(&mut property_data)?); - } - PropertyType::TopicAliasMaximum => { - if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); - } - properties.topic_alias_maximum = Some(u16::read(&mut property_data)?); - } - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::WildcardSubscriptionAvailable => { - if properties.wildcards_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable)); - } - properties.wildcards_available = Some(bool::read(&mut property_data)?); - } - PropertyType::SubscriptionIdentifierAvailable => { - if properties.subscription_ids_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable)); - } - properties.subscription_ids_available = Some(bool::read(&mut property_data)?); - } - PropertyType::SharedSubscriptionAvailable => { - if properties.shared_subscription_available.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable)); - } - properties.shared_subscription_available = Some(bool::read(&mut property_data)?); - } - PropertyType::ServerKeepAlive => { - if properties.server_keep_alive.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive)); - } - properties.server_keep_alive = Some(u16::read(&mut property_data)?); - } - PropertyType::ResponseInformation => { - if properties.response_info.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation)); - } - properties.response_info = Some(Box::::read(&mut property_data)?); - } - PropertyType::ServerReference => { - if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); - } - properties.server_reference = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Some(Vec::::read(&mut property_data)?); - } - - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::ConnAck)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for ConnAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - let Self { - session_expiry_interval, - receive_maximum, - maximum_qos, - retain_available, - maximum_packet_size, - assigned_client_id, - topic_alias_maximum, - reason_string, - user_properties, - wildcards_available, - subscription_ids_available, - shared_subscription_available, - server_keep_alive, - response_info, - server_reference, - authentication_method, - authentication_data, - } = self; - - if let Some(session_expiry_interval) = session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(*session_expiry_interval); - } - if let Some(receive_maximum) = receive_maximum { - PropertyType::ReceiveMaximum.write(buf)?; - buf.put_u16(*receive_maximum); - } - if let Some(maximum_qos) = maximum_qos { - PropertyType::MaximumQos.write(buf)?; - maximum_qos.write(buf)?; - } - if let Some(retain_available) = retain_available { - PropertyType::RetainAvailable.write(buf)?; - retain_available.write(buf)?; - } - if let Some(maximum_packet_size) = maximum_packet_size { - PropertyType::MaximumPacketSize.write(buf)?; - buf.put_u32(*maximum_packet_size); - } - if let Some(client_id) = assigned_client_id { - PropertyType::AssignedClientIdentifier.write(buf)?; - client_id.write(buf)?; - } - if let Some(topic_alias_maximum) = topic_alias_maximum { - PropertyType::TopicAliasMaximum.write(buf)?; - buf.put_u16(*topic_alias_maximum); - } - if let Some(reason_string) = reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, val) in user_properties.iter() { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - val.write(buf)?; - } - if let Some(wildcards_available) = wildcards_available { - PropertyType::WildcardSubscriptionAvailable.write(buf)?; - wildcards_available.write(buf)?; - } - if let Some(subscription_ids_available) = subscription_ids_available { - PropertyType::SubscriptionIdentifierAvailable.write(buf)?; - subscription_ids_available.write(buf)?; - } - if let Some(shared_subscription_available) = shared_subscription_available { - PropertyType::SharedSubscriptionAvailable.write(buf)?; - shared_subscription_available.write(buf)?; - } - if let Some(server_keep_alive) = server_keep_alive { - PropertyType::ServerKeepAlive.write(buf)?; - server_keep_alive.write(buf)?; - } - if let Some(response_info) = response_info { - PropertyType::ResponseInformation.write(buf)?; - response_info.write(buf)?; - } - if let Some(server_reference) = server_reference { - PropertyType::ServerReference.write(buf)?; - server_reference.write(buf)?; - } - if let Some(authentication_method) = &authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if let Some(authentication_data) = authentication_data { - if authentication_method.is_none() { - return Err(SerializeError::AuthDataWithoutAuthMethod); - } - PropertyType::AuthenticationData.write(buf)?; - authentication_data.write(buf)?; - } - - Ok(()) - } -} - -// impl WireLength for ConnAckProperties { -// fn wire_len(&self) -> usize { -// let mut len: usize = 0; - -// if self.session_expiry_interval.is_some() { -// len += 1 + 4; -// } -// if self.receive_maximum.is_some() { -// len += 1 + 2; -// } -// if self.maximum_qos.is_some() { -// len += 1 + 1; -// } -// if self.retain_available.is_some() { -// len += 1 + 1; -// } -// if self.maximum_packet_size.is_some() { -// len += 1 + 4; -// } -// if let Some(client_id) = &self.assigned_client_id { -// len += 1 + client_id.wire_len(); -// } -// if self.topic_alias_maximum.is_some() { -// len += 1 + 2; -// } -// if let Some(reason_string) = &self.reason_string { -// len += 1 + reason_string.wire_len(); -// } -// for (key, value) in &self.user_properties { -// len += 1; -// len += key.wire_len(); -// len += value.wire_len(); -// } -// if self.wildcards_available.is_some() { -// len += 1 + 1; -// } -// if self.subscription_ids_available.is_some() { -// len += 1 + 1; -// } -// if self.shared_subscription_available.is_some() { -// len += 1 + 1; -// } -// if self.server_keep_alive.is_some() { -// len += 1 + 2; -// } -// if let Some(response_info) = &self.response_info { -// len += 1 + response_info.wire_len(); -// } -// if let Some(server_reference) = &self.server_reference { -// len += 1 + server_reference.wire_len(); -// } -// if let Some(authentication_method) = &self.authentication_method { -// len += 1 + authentication_method.wire_len(); -// } -// if self.authentication_data.is_some() && self.authentication_method.is_some() { -// len += 1 + self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); -// } - -// len -// } -// } - -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct ConnAckFlags { - pub session_present: bool, -} - -impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncReadExt + Unpin { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; - Ok((Self { - session_present: (byte & 0b00000001) == 0b00000001, - }, 1)) - } - } -} - -impl MqttRead for ConnAckFlags { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - let byte = buf.get_u8(); - - Ok(Self { - session_present: (byte & 0b00000001) == 0b00000001, - }) - } -} - -impl MqttWrite for ConnAckFlags { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let byte = self.session_present as u8; - - buf.put_u8(byte); - Ok(()) - } -} - -#[cfg(test)] -mod tests { - - use crate::packets::{ - connack::{ConnAck, ConnAckProperties}, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - reason_codes::ConnAckReasonCode, - Packet, - }; - - #[test] - fn test_wire_len() { - let mut buf = bytes::BytesMut::new(); - - let connack_properties = ConnAckProperties { - session_expiry_interval: Some(60), // Session expiry interval in seconds - receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently - maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server - retain_available: Some(true), // Whether the server supports retained messages - maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept - assigned_client_id: Some(Box::from("client-12345")), // Client identifier assigned by the server - topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server - reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment - user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair - wildcards_available: Some(true), // Whether wildcard subscriptions are available - subscription_ids_available: Some(true), // Whether subscription identifiers are available - shared_subscription_available: Some(true), // Whether shared subscriptions are available - server_keep_alive: Some(120), // Server keep alive time in seconds - response_info: Some(Box::from("Response info")), // Response information - server_reference: Some(Box::from("server-reference")), // Server reference - authentication_method: Some(Box::from("auth-method")), // Authentication method - authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data - }; - - let len = connack_properties.wire_len(); - // determine length of variable integer - let len_of_wire_len = crate::packets::write_variable_integer(&mut buf, len).unwrap(); - // clear buffer before writing actual properties - buf.clear(); - connack_properties.write(&mut buf).unwrap(); - - assert_eq!(len + len_of_wire_len, buf.len()); - - } - - #[test] - fn read_write_connack_packet() { - let c = ConnAck { ..Default::default() }; - - let p1 = Packet::ConnAck(c); - let mut buf = bytes::BytesMut::new(); - - p1.write(&mut buf).unwrap(); - - let p2 = Packet::read_from_buffer(&mut buf).unwrap(); - - assert_eq!(p1, p2); - } - - #[test] - fn read_write_connack() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - 0x01, // Connack flags - 0x00, // Reason code, - 0x00, // empty properties - ]; - - buf.extend_from_slice(packet); - let c1 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); - - assert_eq!(ConnAckReasonCode::Success, c1.reason_code); - assert_eq!(ConnAckProperties::default(), c1.connack_properties); - - let mut buf = bytes::BytesMut::new(); - - c1.write(&mut buf).unwrap(); - - let c2 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); - - assert_eq!(c1, c2) - } - - #[test] - fn read_write_connack_properties() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - 56, // ConnAckProperties variable length - 17, // session_expiry_interval - 0xff, 0xff, 37, // retain_available - 0x1, // true - 18, // Assigned Client Id - 0, 11, // 11 bytes "KeanuReeves" without space - b'K', b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', 36, // Max QoS - 2, // QoS 2 Exactly Once - 34, // Topic Alias Max = 255 - 0, 255, 31, // Reason String = 'Houston we have got a problem' - 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', - ]; - - buf.extend_from_slice(packet); - let c1 = ConnAckProperties::read(&mut buf.into()).unwrap(); - - let mut buf = bytes::BytesMut::new(); - - let variable_length = c1.wire_len(); - assert_eq!(variable_length, 56); - - c1.write(&mut buf).unwrap(); - - let _buf_clone = buf.to_vec(); - - let c2 = ConnAckProperties::read(&mut buf.into()).unwrap(); - - assert_eq!(c1, c2); - } -} +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, QoS, VariableInteger, +}; +use bytes::{Buf, BufMut}; + +crate::packets::macros::define_properties!(ConnAckProperties, + SessionExpiryInterval, + ReceiveMaximum, + MaximumQos, + RetainAvailable, + MaximumPacketSize, + AssignedClientIdentifier, + TopicAliasMaximum, + ReasonString, + UserProperty, + WildcardSubscriptionAvailable, + SubscriptionIdentifierAvailable, + SharedSubscriptionAvailable, + ServerKeepAlive, + ResponseInformation, + ServerReference, + AuthenticationMethod, + AuthenticationData +); + +// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// pub struct ConnAckProperties { +// /// 3.2.2.3.2 Session Expiry Interval +// /// 17 (0x11) Byte Identifier of the Session Expiry Interval +// pub session_expiry_interval: Option, + +// /// 3.2.2.3.3 Receive Maximum +// /// 33 (0x21) Byte, Identifier of the Receive Maximum +// pub receive_maximum: Option, + +// /// 3.2.2.3.4 Maximum QoS +// /// 36 (0x24) Byte, Identifier of the Maximum QoS. +// pub maximum_qos: Option, + +// /// 3.2.2.3.5 Retain Available +// /// 37 (0x25) Byte, Identifier of Retain Available. +// pub retain_available: Option, + +// /// 3.2.2.3.6 Maximum Packet Size +// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. +// pub maximum_packet_size: Option, + +// /// 3.2.2.3.7 Assigned Client Identifier +// /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. +// pub assigned_client_id: Option>, + +// /// 3.2.2.3.8 Topic Alias Maximum +// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. +// pub topic_alias_maximum: Option, + +// /// 3.2.2.3.9 Reason String +// /// 31 (0x1F) Byte Identifier of the Reason String. +// pub reason_string: Option>, + +// /// 3.2.2.3.10 User Property +// /// 38 (0x26) Byte, Identifier of User Property. +// pub user_properties: Vec<(Box, Box)>, + +// /// 3.2.2.3.11 Wildcard Subscription Available +// /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. +// pub wildcards_available: Option, + +// /// 3.2.2.3.12 Subscription Identifiers Available +// /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. +// pub subscription_ids_available: Option, + +// /// 3.2.2.3.13 Shared Subscription Available +// /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. +// pub shared_subscription_available: Option, + +// /// 3.2.2.3.14 Server Keep Alive +// /// 19 (0x13) Byte, Identifier of the Server Keep Alive +// pub server_keep_alive: Option, + +// /// 3.2.2.3.15 Response Information +// /// 26 (0x1A) Byte, Identifier of the Response Information. +// pub response_info: Option>, + +// /// 3.2.2.3.16 Server Reference +// /// 28 (0x1C) Byte, Identifier of the Server Reference +// pub server_reference: Option>, + +// /// 3.2.2.3.17 Authentication Method +// /// 21 (0x15) Byte, Identifier of the Authentication Method +// pub authentication_method: Option>, + +// /// 3.2.2.3.18 Authentication Data +// /// 22 (0x16) Byte, Identifier of the Authentication Data +// // There is a small inconsistency here with authentication_data in the connect packet. +// // This is Option while that type uses just Bytes. +// pub authentication_data: Option>, +// } + +impl MqttRead for ConnAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + let property = PropertyType::read(&mut property_data)?; + match property { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ReceiveMaximum => { + if properties.receive_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); + } + properties.receive_maximum = Some(u16::read(&mut property_data)?); + } + PropertyType::MaximumQos => { + if properties.maximum_qos.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumQos)); + } + properties.maximum_qos = Some(QoS::read(&mut property_data)?); + } + PropertyType::RetainAvailable => { + if properties.retain_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable)); + } + properties.retain_available = Some(bool::read(&mut property_data)?); + } + PropertyType::MaximumPacketSize => { + if properties.maximum_packet_size.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); + } + properties.maximum_packet_size = Some(u32::read(&mut property_data)?); + } + PropertyType::AssignedClientIdentifier => { + if properties.assigned_client_id.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); + } + properties.assigned_client_id = Some(Box::::read(&mut property_data)?); + } + PropertyType::TopicAliasMaximum => { + if properties.topic_alias_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); + } + properties.topic_alias_maximum = Some(u16::read(&mut property_data)?); + } + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::WildcardSubscriptionAvailable => { + if properties.wildcards_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable)); + } + properties.wildcards_available = Some(bool::read(&mut property_data)?); + } + PropertyType::SubscriptionIdentifierAvailable => { + if properties.subscription_ids_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable)); + } + properties.subscription_ids_available = Some(bool::read(&mut property_data)?); + } + PropertyType::SharedSubscriptionAvailable => { + if properties.shared_subscription_available.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable)); + } + properties.shared_subscription_available = Some(bool::read(&mut property_data)?); + } + PropertyType::ServerKeepAlive => { + if properties.server_keep_alive.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive)); + } + properties.server_keep_alive = Some(u16::read(&mut property_data)?); + } + PropertyType::ResponseInformation => { + if properties.response_info.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation)); + } + properties.response_info = Some(Box::::read(&mut property_data)?); + } + PropertyType::ServerReference => { + if properties.server_reference.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); + } + properties.server_reference = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::ConnAck)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for ConnAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + let Self { + session_expiry_interval, + receive_maximum, + maximum_qos, + retain_available, + maximum_packet_size, + assigned_client_id, + topic_alias_maximum, + reason_string, + user_properties, + wildcards_available, + subscription_ids_available, + shared_subscription_available, + server_keep_alive, + response_info, + server_reference, + authentication_method, + authentication_data, + } = self; + + if let Some(session_expiry_interval) = session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(*session_expiry_interval); + } + if let Some(receive_maximum) = receive_maximum { + PropertyType::ReceiveMaximum.write(buf)?; + buf.put_u16(*receive_maximum); + } + if let Some(maximum_qos) = maximum_qos { + PropertyType::MaximumQos.write(buf)?; + maximum_qos.write(buf)?; + } + if let Some(retain_available) = retain_available { + PropertyType::RetainAvailable.write(buf)?; + retain_available.write(buf)?; + } + if let Some(maximum_packet_size) = maximum_packet_size { + PropertyType::MaximumPacketSize.write(buf)?; + buf.put_u32(*maximum_packet_size); + } + if let Some(client_id) = assigned_client_id { + PropertyType::AssignedClientIdentifier.write(buf)?; + client_id.write(buf)?; + } + if let Some(topic_alias_maximum) = topic_alias_maximum { + PropertyType::TopicAliasMaximum.write(buf)?; + buf.put_u16(*topic_alias_maximum); + } + if let Some(reason_string) = reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, val) in user_properties.iter() { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + val.write(buf)?; + } + if let Some(wildcards_available) = wildcards_available { + PropertyType::WildcardSubscriptionAvailable.write(buf)?; + wildcards_available.write(buf)?; + } + if let Some(subscription_ids_available) = subscription_ids_available { + PropertyType::SubscriptionIdentifierAvailable.write(buf)?; + subscription_ids_available.write(buf)?; + } + if let Some(shared_subscription_available) = shared_subscription_available { + PropertyType::SharedSubscriptionAvailable.write(buf)?; + shared_subscription_available.write(buf)?; + } + if let Some(server_keep_alive) = server_keep_alive { + PropertyType::ServerKeepAlive.write(buf)?; + server_keep_alive.write(buf)?; + } + if let Some(response_info) = response_info { + PropertyType::ResponseInformation.write(buf)?; + response_info.write(buf)?; + } + if let Some(server_reference) = server_reference { + PropertyType::ServerReference.write(buf)?; + server_reference.write(buf)?; + } + if let Some(authentication_method) = &authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = authentication_data { + if authentication_method.is_none() { + return Err(SerializeError::AuthDataWithoutAuthMethod); + } + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + + Ok(()) + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/connack/reason_code.rs b/mqrstt/src/packets/connack/reason_code.rs new file mode 100644 index 0000000..51edabe --- /dev/null +++ b/mqrstt/src/packets/connack/reason_code.rs @@ -0,0 +1,24 @@ +crate::packets::macros::reason_code!(ConnAckReasonCode, + Success, + UnspecifiedError, + MalformedPacket, + ProtocolError, + ImplementationSpecificError, + UnsupportedProtocolVersion, + ClientIdentifierNotValid, + BadUsernameOrPassword, + NotAuthorized, + ServerUnavailable, + ServerBusy, + Banned, + BadAuthenticationMethod, + TopicNameInvalid, + PacketTooLarge, + QuotaExceeded, + PayloadFormatInvalid, + RetainNotSupported, + QosNotSupported, + UseAnotherServer, + ServerMoved, + ConnectionRateExceeded +); \ No newline at end of file diff --git a/mqrstt/src/packets/connect.rs b/mqrstt/src/packets/connect.rs deleted file mode 100644 index c13d1aa..0000000 --- a/mqrstt/src/packets/connect.rs +++ /dev/null @@ -1,892 +0,0 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; - -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, ProtocolVersion, QoS, WireLength, -}; - -/// Variable connect header: -/// -/// -/// ╔═══════════╦═══════════════════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╗ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ ║ Description ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═══════════╩═══════════════════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╣ -/// ║ ║ -/// ║ Protocol Name ║ -/// ╠═══════════╦═══════════════════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╦══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 1 ║ Length MSB (0) ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 2 ║ Length LSB (4) ║ 0 ║ 0 ║ 0 ║ 0 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 3 ║ ‘M’ ║ 0 ║ 1 ║ 0 ║ 0 ║ 1 ║ 1 ║ 0 ║ 1 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 4 ║ ‘Q’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ 0 ║ 1 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 5 ║ ‘T’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╠═══════════╬═══════════════════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╬══════╣ -/// ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ ║ -/// ║ byte 6 ║ ‘T’ ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 1 ║ 0 ║ 0 ║ -/// ╚═══════════╩═══════════════════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╩══════╝ -/// -/// Byte 7: -/// The protocol version -/// -/// Byte 8: -/// 3.1.2.3 Connect Flags : -/// ╔═════╦═══════════╦══════════╦═════════════╦═════╦════╦═══════════╦═════════════╦══════════╗ -/// ║ Bit ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═════╬═══════════╬══════════╬═════════════╬═════╩════╬═══════════╬═════════════╬══════════╣ -/// ║ ║ User Name ║ Password ║ Will Retain ║ Will QoS ║ Will Flag ║ Clean Start ║ Reserved ║ -/// ╚═════╩═══════════╩══════════╩═════════════╩══════════╩═══════════╩═════════════╩══════════╝ -/// -/// Byte 9 and 10: -/// The keep alive -/// -/// Byte 11: -/// Length of [`ConnectProperties`] -/// -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Connect { - /// Byte 7 - pub protocol_version: ProtocolVersion, - - /// 3.1.2.4 Clean Start Flag - /// bit 1 - pub clean_start: bool, - - /// 3.1.2.5 Will Flag through option - pub last_will: Option, - - /// 3.1.2.8 User Name Flag - pub username: Option>, - /// 3.1.2.9 Password Flag - pub password: Option>, - - /// 3.1.2.10 Keep Alive - /// Byte 9 and 10 - pub keep_alive: u16, - - /// 3.1.2.11 CONNECT Properties - pub connect_properties: ConnectProperties, - - /// 3.1.3.1 Client Identifier (ClientID) - pub client_id: Box, -} - -impl Default for Connect { - fn default() -> Self { - Self { - protocol_version: ProtocolVersion::V5, - clean_start: true, - last_will: None, - username: None, - password: None, - keep_alive: 60, - connect_properties: ConnectProperties::default(), - client_id: "MQRSTT".into(), - } - } -} - -impl PacketRead for Connect { - fn read(_: u8, _: usize, mut buf: Bytes) -> Result { - if String::read(&mut buf)? != "MQTT" { - return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); - } - - let protocol_version = ProtocolVersion::read(&mut buf)?; - - let connect_flags = ConnectFlags::read(&mut buf)?; - - let clean_start = connect_flags.clean_start; - let keep_alive = buf.get_u16(); - - let connect_properties = ConnectProperties::read(&mut buf)?; - - let client_id = Box::::read(&mut buf)?; - let mut last_will = None; - if connect_flags.will_flag { - let retain = connect_flags.will_retain; - - last_will = Some(LastWill::read(connect_flags.will_qos, retain, &mut buf)?); - } - - let username = if connect_flags.username { Some(Box::::read(&mut buf)?) } else { None }; - let password = if connect_flags.password { Some(Box::::read(&mut buf)?) } else { None }; - - let connect = Connect { - protocol_version, - clean_start, - last_will, - username, - password, - keep_alive, - connect_properties, - client_id, - }; - - Ok(connect) - } -} - -impl PacketWrite for Connect { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - "MQTT".write(buf)?; - - self.protocol_version.write(buf)?; - - let mut connect_flags = ConnectFlags { - clean_start: self.clean_start, - ..Default::default() - }; - - if let Some(last_will) = &self.last_will { - connect_flags.will_flag = true; - connect_flags.will_retain = last_will.retain; - connect_flags.will_qos = last_will.qos; - } - connect_flags.username = self.username.is_some(); - connect_flags.password = self.password.is_some(); - - connect_flags.write(buf)?; - - buf.put_u16(self.keep_alive); - - self.connect_properties.write(buf)?; - - self.client_id.write(buf)?; - - if let Some(last_will) = &self.last_will { - last_will.write(buf)?; - } - if let Some(username) = &self.username { - username.write(buf)?; - } - if let Some(password) = &self.password { - password.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for Connect { - fn wire_len(&self) -> usize { - let mut len = "MQTT".wire_len() + 1 + 1 + 2; // protocol version, connect_flags and keep alive - - len += variable_integer_len(self.connect_properties.wire_len()); - len += self.connect_properties.wire_len(); - - if let Some(last_will) = &self.last_will { - len += last_will.wire_len(); - } - if let Some(username) = &self.username { - len += username.wire_len() - } - if let Some(password) = &self.password { - len += password.wire_len() - } - - len += self.client_id.wire_len(); - - len - } -} - -/// ╔═════╦═══════════╦══════════╦═════════════╦═════╦════╦═══════════╦═════════════╦══════════╗ -/// ║ Bit ║ 7 ║ 6 ║ 5 ║ 4 ║ 3 ║ 2 ║ 1 ║ 0 ║ -/// ╠═════╬═══════════╬══════════╬═════════════╬═════╩════╬═══════════╬═════════════╬══════════╣ -/// ║ ║ User Name ║ Password ║ Will Retain ║ Will QoS ║ Will Flag ║ Clean Start ║ Reserved ║ -/// ╚═════╩═══════════╩══════════╩═════════════╩══════════╩═══════════╩═════════════╩══════════╝ -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct ConnectFlags { - pub clean_start: bool, - pub will_flag: bool, - pub will_qos: QoS, - pub will_retain: bool, - pub password: bool, - pub username: bool, -} - -impl ConnectFlags { - pub fn from_u8(value: u8) -> Result { - Ok(Self { - clean_start: ((value & 0b00000010) >> 1) != 0, - will_flag: ((value & 0b00000100) >> 2) != 0, - will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, - will_retain: ((value & 0b00100000) >> 5) != 0, - password: ((value & 0b01000000) >> 6) != 0, - username: ((value & 0b10000000) >> 7) != 0, - }) - } - - pub fn into_u8(&self) -> Result { - let byte = ((self.clean_start as u8) << 1) - | ((self.will_flag as u8) << 2) - | (self.will_qos.into_u8() << 3) - | ((self.will_retain as u8) << 5) - | ((self.password as u8) << 6) - | ((self.username as u8) << 7); - Ok(byte) - } -} - -impl Default for ConnectFlags { - fn default() -> Self { - Self { - clean_start: false, - will_flag: false, - will_qos: QoS::AtMostOnce, - will_retain: false, - password: false, - username: false, - } - } -} - -impl MqttRead for ConnectFlags { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - let byte = buf.get_u8(); - - ConnectFlags::from_u8(byte) - } -} - -impl MqttWrite for ConnectFlags { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - buf.put_u8(self.into_u8()?); - Ok(()) - } -} - -/// Connect Properties -/// -/// The wire representation starts with the length of all properties after which -/// the identifiers and their actual value are given -/// -/// 3.1.2.11.1 Property Length -/// The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. -/// Followed by all possible connect properties: -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct ConnectProperties { - /// 3.1.2.11.2 Session Expiry Interval - /// 17 (0x11) Byte Identifier of the Session Expiry Interval - pub session_expiry_interval: Option, - - /// 3.1.2.11.3 Receive Maximum - /// 33 (0x21) Byte, Identifier of the Receive Maximum - pub receive_maximum: Option, - - /// 3.1.2.11.4 Maximum Packet Size - /// 39 (0x27) Byte, Identifier of the Maximum Packet Size - pub maximum_packet_size: Option, - - /// 3.1.2.11.5 Topic Alias Maximum - /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum - pub topic_alias_maximum: Option, - - /// 3.1.2.11.6 Request Response Information - /// 25 (0x19) Byte, Identifier of the Request Response Information - pub request_response_information: Option, - - /// 3.1.2.11.7 Request Problem Information - /// 23 (0x17) Byte, Identifier of the Request Problem Information - pub request_problem_information: Option, - - /// 3.1.2.11.8 User Property - /// 38 (0x26) Byte, Identifier of the User Property - pub user_properties: Vec<(Box, Box)>, - - /// 3.1.2.11.9 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method - pub authentication_method: Option>, - - /// 3.1.2.11.10 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - pub authentication_data: Bytes, -} - -impl MqttWrite for ConnectProperties { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(session_expiry_interval) = self.session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(session_expiry_interval); - } - if let Some(receive_maximum) = self.receive_maximum { - PropertyType::ReceiveMaximum.write(buf)?; - buf.put_u16(receive_maximum); - } - if let Some(maximum_packet_size) = self.maximum_packet_size { - PropertyType::MaximumPacketSize.write(buf)?; - buf.put_u32(maximum_packet_size); - } - if let Some(topic_alias_maximum) = self.topic_alias_maximum { - PropertyType::TopicAliasMaximum.write(buf)?; - buf.put_u16(topic_alias_maximum); - } - if let Some(request_response_information) = self.request_response_information { - PropertyType::RequestResponseInformation.write(buf)?; - buf.put_u8(request_response_information); - } - if let Some(request_problem_information) = self.request_problem_information { - PropertyType::RequestProblemInformation.write(buf)?; - buf.put_u8(request_problem_information); - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - if let Some(authentication_method) = &self.authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if !self.authentication_data.is_empty() { - if self.authentication_method.is_none() { - return Err(SerializeError::AuthDataWithoutAuthMethod); - } - PropertyType::AuthenticationData.write(buf)?; - self.authentication_data.write(buf)?; - } - - Ok(()) - } -} - -impl MqttRead for ConnectProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(property_data.get_u32()); - } - PropertyType::ReceiveMaximum => { - if properties.receive_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); - } - properties.receive_maximum = Some(property_data.get_u16()); - } - PropertyType::MaximumPacketSize => { - if properties.maximum_packet_size.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); - } - properties.maximum_packet_size = Some(property_data.get_u32()); - } - PropertyType::TopicAliasMaximum => { - if properties.topic_alias_maximum.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); - } - properties.topic_alias_maximum = Some(property_data.get_u16()); - } - PropertyType::RequestResponseInformation => { - if properties.request_response_information.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation)); - } - properties.request_response_information = Some(property_data.get_u8()); - } - PropertyType::RequestProblemInformation => { - if properties.request_problem_information.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation)); - } - properties.request_problem_information = Some(property_data.get_u8()); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(Box::::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Bytes::read(&mut property_data)?; - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), - } - - if property_data.is_empty() { - break; - } - } - - if !properties.authentication_data.is_empty() && properties.authentication_method.is_none() { - return Err(DeserializeError::MalformedPacketWithInfo("Authentication data is not empty while authentication method is".to_string())); - } - - Ok(properties) - } -} - -impl WireLength for ConnectProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.session_expiry_interval.is_some() { - len += 1 + 4; - } - if self.receive_maximum.is_some() { - len += 1 + 2; - } - if self.maximum_packet_size.is_some() { - len += 1 + 4; - } - if self.topic_alias_maximum.is_some() { - len += 1 + 2; - } - if self.request_response_information.is_some() { - len += 2; - } - if self.request_problem_information.is_some() { - len += 2; - } - for (key, value) in &self.user_properties { - len += 1; - len += key.wire_len(); - len += value.wire_len(); - } - if let Some(authentication_method) = &self.authentication_method { - len += 1 + authentication_method.wire_len(); - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - len += 1 + self.authentication_data.wire_len(); - } - - len - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LastWill { - /// 3.1.2.6 Will QoS - pub qos: QoS, - /// 3.1.2.7 Will Retain - pub retain: bool, - - /// 3.1.3.2 Will properties - pub last_will_properties: LastWillProperties, - /// 3.1.3.3 Will Topic - pub topic: Box, - /// 3.1.3.4 Will payload - pub payload: Bytes, -} - -impl LastWill { - pub fn new, P: Into>>(qos: QoS, retain: bool, topic: T, payload: P) -> LastWill { - Self { - qos, - retain, - last_will_properties: LastWillProperties::default(), - topic: topic.as_ref().into(), - payload: Bytes::from(payload.into()), - } - } - pub fn read(qos: QoS, retain: bool, buf: &mut Bytes) -> Result { - let last_will_properties = LastWillProperties::read(buf)?; - let topic = Box::::read(buf)?; - let payload = Bytes::read(buf)?; - - Ok(Self { - qos, - retain, - topic, - payload, - last_will_properties, - }) - } -} - -impl MqttWrite for LastWill { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - self.last_will_properties.write(buf)?; - self.topic.write(buf)?; - self.payload.write(buf)?; - Ok(()) - } -} - -impl WireLength for LastWill { - fn wire_len(&self) -> usize { - let property_len = self.last_will_properties.wire_len(); - - self.topic.wire_len() + self.payload.wire_len() + variable_integer_len(property_len) + property_len - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct LastWillProperties { - /// 3.1.3.2.2 Will Delay Interval - delay_interval: Option, - /// 3.1.3.2.3 Payload Format Indicator - payload_format_indicator: Option, - /// 3.1.3.2.4 Message Expiry Interval - message_expiry_interval: Option, - /// 3.1.3.2.5 Content Type - content_type: Option>, - /// 3.1.3.2.6 Response Topic - response_topic: Option>, - /// 3.1.3.2.7 Correlation Data - correlation_data: Option, - /// 3.1.3.2.8 User Property - user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for LastWillProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::WillDelayInterval => { - if properties.delay_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval)); - } - properties.delay_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::PayloadFormatIndicator => { - if properties.payload_format_indicator.is_none() { - return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); - } - properties.payload_format_indicator = Some(u8::read(&mut property_data)?); - } - PropertyType::MessageExpiryInterval => { - if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); - } - properties.message_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ContentType => { - if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); - } - properties.content_type = Some(Box::::read(&mut property_data)?); - } - PropertyType::ResponseTopic => { - if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); - } - properties.response_topic = Some(Box::::read(&mut property_data)?); - } - PropertyType::CorrelationData => { - if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); - } - properties.correlation_data = Some(Bytes::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for LastWillProperties { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(delay_interval) = self.delay_interval { - PropertyType::WillDelayInterval.write(buf)?; - buf.put_u32(delay_interval); - } - if let Some(payload_format_indicator) = self.payload_format_indicator { - PropertyType::PayloadFormatIndicator.write(buf)?; - buf.put_u8(payload_format_indicator); - } - if let Some(message_expiry_interval) = self.message_expiry_interval { - PropertyType::MessageExpiryInterval.write(buf)?; - buf.put_u32(message_expiry_interval); - } - if let Some(content_type) = &self.content_type { - PropertyType::ContentType.write(buf)?; - content_type.write(buf)?; - } - if let Some(response_topic) = &self.response_topic { - PropertyType::ResponseTopic.write(buf)?; - response_topic.write(buf)?; - } - if let Some(correlation_data) = &self.correlation_data { - PropertyType::CorrelationData.write(buf)?; - correlation_data.write(buf)?; - } - if !self.user_properties.is_empty() { - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - } - Ok(()) - } -} - -impl WireLength for LastWillProperties { - fn wire_len(&self) -> usize { - let mut len: usize = 0; - - if self.delay_interval.is_some() { - len += 5; - } - if self.payload_format_indicator.is_some() { - len += 2; - } - if self.message_expiry_interval.is_some() { - len += 5; - } - // +1 for the property type - len += self.content_type.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); - len += self.response_topic.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); - len += self.correlation_data.as_ref().map_or_else(|| 0, |b| b.wire_len() + 1); - for (key, value) in &self.user_properties { - len += key.wire_len() + value.wire_len() + 1; - } - - len - } -} - -#[cfg(test)] -mod tests { - use crate::packets::{ - mqtt_traits::{MqttWrite, PacketRead, PacketWrite}, - QoS, - }; - - use super::{Connect, ConnectFlags, LastWill}; - - #[test] - fn read_connect() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - // 0x10, - // 39, // packet type, flags and remaining len - 0x00, - 0x04, - b'M', - b'Q', - b'T', - b'T', - 0x05, - 0b1100_1110, // Connect Flags, username, password, will retain=false, will qos=1, last_will, clean_start - 0x00, // Keep alive = 10 sec - 0x0a, - 0x00, // Length of Connect properties - 0x00, // client_id length - 0x04, - b't', // client_id - b'e', - b's', - b't', - 0x00, // Will properties length - 0x00, // length topic - 0x02, - b'/', // Will topic = '/a' - b'a', - 0x00, // Will payload length - 0x0B, - b'h', // Will payload = 'hello world' - b'e', - b'l', - b'l', - b'o', - b' ', - b'w', - b'o', - b'r', - b'l', - b'd', - 0x00, // length username - 0x04, - b'u', // username = 'user' - b's', - b'e', - b'r', - 0x00, // length password - 0x04, - b'p', // Password = 'pass' - b'a', - b's', - b's', - 0xAB, // extra packets in the stream - 0xCD, - 0xEF, - ]; - - buf.extend_from_slice(packet); - let c = Connect::read(0, 0, buf.into()).unwrap(); - - dbg!(c); - } - - #[test] - fn read_and_write_connect() { - let mut buf = bytes::BytesMut::new(); - let packet = &[ - // 0x10, - // 39, // packet type, flags and remaining len - 0x00, - 0x04, - b'M', - b'Q', - b'T', - b'T', - 0x05, // variable header - 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session - 0x00, // Keep alive = 10 sec - 0x0a, - 0x00, // Length of Connect properties - 0x00, // client_id length - 0x04, - b't', // client_id - b'e', - b's', - b't', - 0x00, // Will properties length - 0x00, // length topic - 0x02, - b'/', // Will topic = '/a' - b'a', - 0x00, // Will payload length - 0x0B, - b'h', // Will payload = 'hello world' - b'e', - b'l', - b'l', - b'o', - b' ', - b'w', - b'o', - b'r', - b'l', - b'd', - 0x00, // length username - 0x04, - b'u', // username - b's', - b'e', - b'r', - 0x00, // length password - 0x04, - b'p', // payload. password = 'pass' - b'a', - b's', - b's', - ]; - - buf.extend_from_slice(packet); - let c = Connect::read(0, 0, buf.into()).unwrap(); - - let mut write_buf = bytes::BytesMut::new(); - c.write(&mut write_buf).unwrap(); - - assert_eq!(packet.to_vec(), write_buf.to_vec()); - - dbg!(c); - } - - #[test] - fn parsing_last_will() { - let last_will = &[ - 0x00, // Will properties length - 0x00, // length topic - 0x02, b'/', // Will topic = '/a' - b'a', 0x00, // Will payload length - 0x0B, b'h', // Will payload = 'hello world' - b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', - ]; - let mut buf = bytes::Bytes::from_static(last_will); - - assert!(LastWill::read(QoS::AtLeastOnce, false, &mut buf).is_ok()); - } - - #[test] - fn read_and_write_connect2() { - let _packet = [ - 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, - ]; - - let data = [ - 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, - ]; - - let mut buf = bytes::BytesMut::new(); - buf.extend_from_slice(&data); - - let c = Connect::read(0, 0, buf.into()).unwrap(); - - dbg!(c.clone()); - - let mut write_buf = bytes::BytesMut::new(); - c.write(&mut write_buf).unwrap(); - - assert_eq!(data.to_vec(), write_buf.to_vec()); - } - - #[test] - fn parsing_and_writing_last_will() { - let last_will = &[ - 0x00, // Will properties length - 0x00, // length topic - 0x02, b'/', // Will topic = '/a' - b'a', 0x00, // Will payload length - 0x0B, b'h', // Will payload = 'hello world' - b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', - ]; - let mut buf = bytes::Bytes::from_static(last_will); - - let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap(); - - let mut write_buf = bytes::BytesMut::new(); - lw.write(&mut write_buf).unwrap(); - - assert_eq!(last_will.to_vec(), write_buf.to_vec()); - } - - #[test] - fn connect_flag() { - let byte = 0b1100_1110; - let flags = ConnectFlags::from_u8(byte).unwrap(); - assert_eq!(byte, flags.into_u8().unwrap()); - } -} diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs new file mode 100644 index 0000000..b327563 --- /dev/null +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -0,0 +1,90 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::ReadError; + +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, +}; + +/// The connect flags describe some information related the session. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct ConnectFlags { + /// Indicates whether to start a new session or continue an existing one. + pub clean_start: bool, + /// Specifies if a Will message is included. + pub will_flag: bool, + /// Defines the Quality of Service level for the Will message. + pub will_qos: QoS, + /// Indicates if the Will message should be retained by the broker. + pub will_retain: bool, + /// Shows if a password is included in the payload. + pub password: bool, + /// Shows if a username is included in the payload. + pub username: bool, +} + +impl ConnectFlags { + pub fn from_u8(value: u8) -> Result { + Ok(Self { + clean_start: ((value & 0b00000010) >> 1) != 0, + will_flag: ((value & 0b00000100) >> 2) != 0, + will_qos: QoS::from_u8((value & 0b00011000) >> 3)?, + will_retain: ((value & 0b00100000) >> 5) != 0, + password: ((value & 0b01000000) >> 6) != 0, + username: ((value & 0b10000000) >> 7) != 0, + }) + } + + pub fn into_u8(&self) -> Result { + let byte = ((self.clean_start as u8) << 1) + | ((self.will_flag as u8) << 2) + | (self.will_qos.into_u8() << 3) + | ((self.will_retain as u8) << 5) + | ((self.password as u8) << 6) + | ((self.username as u8) << 7); + Ok(byte) + } +} + +impl Default for ConnectFlags { + fn default() -> Self { + Self { + clean_start: false, + will_flag: false, + will_qos: QoS::AtMostOnce, + will_retain: false, + password: false, + username: false, + } + } +} + +impl MqttRead for ConnectFlags { + fn read(buf: &mut bytes::Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + let byte = buf.get_u8(); + + ConnectFlags::from_u8(byte) + } +} + +impl MqttAsyncRead for ConnectFlags where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(stream: &mut S) -> impl std::future::Future> { + async move { + let byte = stream.read_u8().await?; + Ok((ConnectFlags::from_u8(byte)?, 1)) + } + } +} + +impl MqttWrite for ConnectFlags { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + buf.put_u8(self.into_u8()?); + Ok(()) + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs new file mode 100644 index 0000000..d5f3fa3 --- /dev/null +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -0,0 +1,236 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::ReadError; + +use crate::packets::VariableInteger; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, +}; + + +// /// Connect Properties +// /// +// /// The wire representation starts with the length of all properties after which +// /// the identifiers and their actual value are given +// /// +// /// 3.1.2.11.1 Property Length +// /// The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. +// /// Followed by all possible connect properties: +crate::packets::macros::define_properties!(ConnectProperties, + SessionExpiryInterval, + ReceiveMaximum, + MaximumPacketSize, + TopicAliasMaximum, + RequestResponseInformation, + RequestProblemInformation, + UserProperty, + AuthenticationMethod, + AuthenticationData +); + +// #[derive(Debug, Default, Clone, PartialEq, Eq)] +// pub struct ConnectProperties { +// /// 3.1.2.11.2 Session Expiry Interval +// /// 17 (0x11) Byte Identifier of the Session Expiry Interval +// pub session_expiry_interval: Option, + +// /// 3.1.2.11.3 Receive Maximum +// /// 33 (0x21) Byte, Identifier of the Receive Maximum +// pub receive_maximum: Option, + +// /// 3.1.2.11.4 Maximum Packet Size +// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size +// pub maximum_packet_size: Option, + +// /// 3.1.2.11.5 Topic Alias Maximum +// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum +// pub topic_alias_maximum: Option, + +// /// 3.1.2.11.6 Request Response Information +// /// 25 (0x19) Byte, Identifier of the Request Response Information +// pub request_response_information: Option, + +// /// 3.1.2.11.7 Request Problem Information +// /// 23 (0x17) Byte, Identifier of the Request Problem Information +// pub request_problem_information: Option, + +// /// 3.1.2.11.8 User Property +// /// 38 (0x26) Byte, Identifier of the User Property +// pub user_properties: Vec<(Box, Box)>, + +// /// 3.1.2.11.9 Authentication Method +// /// 21 (0x15) Byte, Identifier of the Authentication Method +// pub authentication_method: Option>, + +// /// 3.1.2.11.10 Authentication Data +// /// 22 (0x16) Byte, Identifier of the Authentication Data +// pub authentication_data: Bytes, +// } + +impl MqttWrite for ConnectProperties { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(session_expiry_interval); + } + if let Some(receive_maximum) = self.receive_maximum { + PropertyType::ReceiveMaximum.write(buf)?; + buf.put_u16(receive_maximum); + } + if let Some(maximum_packet_size) = self.maximum_packet_size { + PropertyType::MaximumPacketSize.write(buf)?; + buf.put_u32(maximum_packet_size); + } + if let Some(topic_alias_maximum) = self.topic_alias_maximum { + PropertyType::TopicAliasMaximum.write(buf)?; + buf.put_u16(topic_alias_maximum); + } + if let Some(request_response_information) = self.request_response_information { + PropertyType::RequestResponseInformation.write(buf)?; + buf.put_u8(request_response_information); + } + if let Some(request_problem_information) = self.request_problem_information { + PropertyType::RequestProblemInformation.write(buf)?; + buf.put_u8(request_problem_information); + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + if let Some(authentication_method) = &self.authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = &self.authentication_data { + if self.authentication_method.is_none() { + return Err(SerializeError::AuthDataWithoutAuthMethod); + } + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + Ok(()) + } +} + +impl MqttRead for ConnectProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(property_data.get_u32()); + } + PropertyType::ReceiveMaximum => { + if properties.receive_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum)); + } + properties.receive_maximum = Some(property_data.get_u16()); + } + PropertyType::MaximumPacketSize => { + if properties.maximum_packet_size.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MaximumPacketSize)); + } + properties.maximum_packet_size = Some(property_data.get_u32()); + } + PropertyType::TopicAliasMaximum => { + if properties.topic_alias_maximum.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum)); + } + properties.topic_alias_maximum = Some(property_data.get_u16()); + } + PropertyType::RequestResponseInformation => { + if properties.request_response_information.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation)); + } + properties.request_response_information = Some(property_data.get_u8()); + } + PropertyType::RequestProblemInformation => { + if properties.request_problem_information.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation)); + } + properties.request_problem_information = Some(property_data.get_u8()); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::AuthenticationMethod => { + if properties.authentication_method.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); + } + properties.authentication_method = Some(Box::::read(&mut property_data)?); + } + PropertyType::AuthenticationData => { + if properties.authentication_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); + } + properties.authentication_data = Some(Vec::::read(&mut property_data)?); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), + } + + if property_data.is_empty() { + break; + } + } + + if properties.authentication_data.as_ref().is_some_and(|data| !data.is_empty()) && properties.authentication_method.is_none() { + return Err(DeserializeError::MalformedPacketWithInfo("Authentication data is not empty while authentication method is".to_string())); + } + + Ok(properties) + } +} + +// impl WireLength for ConnectProperties { +// fn wire_len(&self) -> usize { +// let mut len: usize = 0; + +// if self.session_expiry_interval.is_some() { +// len += 1 + 4; +// } +// if self.receive_maximum.is_some() { +// len += 1 + 2; +// } +// if self.maximum_packet_size.is_some() { +// len += 1 + 4; +// } +// if self.topic_alias_maximum.is_some() { +// len += 1 + 2; +// } +// if self.request_response_information.is_some() { +// len += 2; +// } +// if self.request_problem_information.is_some() { +// len += 2; +// } +// for (key, value) in &self.user_properties { +// len += 1; +// len += key.wire_len(); +// len += value.wire_len(); +// } +// if let Some(authentication_method) = &self.authentication_method { +// len += 1 + authentication_method.wire_len(); +// } +// if !self.authentication_data.is_empty() && self.authentication_method.is_some() { +// len += 1 + self.authentication_data.wire_len(); +// } + +// len +// } +// } diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs new file mode 100644 index 0000000..ede6e49 --- /dev/null +++ b/mqrstt/src/packets/connect/last_will.rs @@ -0,0 +1,83 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::ReadError; + +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, +}; + +use super::{LastWillProperties, VariableInteger}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LastWill { + /// 3.1.2.6 Will QoS + pub qos: QoS, + /// 3.1.2.7 Will Retain + pub retain: bool, + + /// 3.1.3.2 Will properties + pub last_will_properties: LastWillProperties, + /// 3.1.3.3 Will Topic + pub topic: Box, + /// 3.1.3.4 Will payload + pub payload: Vec, +} + +impl LastWill { + pub fn new, P: Into>>(qos: QoS, retain: bool, topic: T, payload: P) -> LastWill { + Self { + qos, + retain, + last_will_properties: LastWillProperties::default(), + topic: topic.as_ref().into(), + payload: payload.into(), + } + } + pub(crate) fn read(qos: QoS, retain: bool, buf: &mut Bytes) -> Result { + let last_will_properties = LastWillProperties::read(buf)?; + let topic = Box::::read(buf)?; + let payload = Vec::::read(buf)?; + + Ok(Self { + qos, + retain, + topic, + payload, + last_will_properties, + }) + } + pub(crate) async fn async_read(qos: QoS, retain: bool, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> where S: tokio::io::AsyncReadExt + Unpin{ + let (last_will_properties, last_will_properties_read_bytes) = LastWillProperties::async_read(stream).await?; + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (payload, payload_read_bytes) = Vec::::async_read(stream).await?; + + let total_read_bytes = last_will_properties_read_bytes + topic_read_bytes + payload_read_bytes; + + Ok((Self { + qos, + retain, + last_will_properties, + topic, + payload, + }, total_read_bytes)) + } +} + +impl MqttWrite for LastWill { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.last_will_properties.write(buf)?; + self.topic.write(buf)?; + self.payload.write(buf)?; + Ok(()) + } +} + +impl WireLength for LastWill { + fn wire_len(&self) -> usize { + let property_len = self.last_will_properties.wire_len(); + + self.topic.wire_len() + self.payload.wire_len() + property_len.variable_integer_len() + property_len + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs new file mode 100644 index 0000000..988a007 --- /dev/null +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -0,0 +1,167 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::ReadError; + +use crate::packets::VariableInteger; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, +}; + +crate::packets::macros::define_properties!(LastWillProperties, + WillDelayInterval, + PayloadFormatIndicator, + MessageExpiryInterval, + ContentType, + ResponseTopic, + CorrelationData, + UserProperty +); + + +// #[derive(Debug, Default, Clone, PartialEq, Eq)] +// pub struct LastWillProperties { +// /// 3.1.3.2.2 Will Delay Interval +// will_delay_interval: Option, +// /// 3.1.3.2.3 Payload Format Indicator +// payload_format_indicator: Option, +// /// 3.1.3.2.4 Message Expiry Interval +// message_expiry_interval: Option, +// /// 3.1.3.2.5 Content Type +// content_type: Option>, +// /// 3.1.3.2.6 Response Topic +// response_topic: Option>, +// /// 3.1.3.2.7 Correlation Data +// correlation_data: Option, +// /// 3.1.3.2.8 User Property +// user_properties: Vec<(Box, Box)>, +// } + +impl MqttRead for LastWillProperties { + fn read(buf: &mut Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut property_data)? { + PropertyType::WillDelayInterval => { + if properties.will_delay_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval)); + } + properties.will_delay_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::PayloadFormatIndicator => { + if properties.payload_format_indicator.is_none() { + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); + } + properties.payload_format_indicator = Some(u8::read(&mut property_data)?); + } + PropertyType::MessageExpiryInterval => { + if properties.message_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); + } + properties.message_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ContentType => { + if properties.content_type.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); + } + properties.content_type = Some(Box::::read(&mut property_data)?); + } + PropertyType::ResponseTopic => { + if properties.response_topic.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); + } + properties.response_topic = Some(Box::::read(&mut property_data)?); + } + PropertyType::CorrelationData => { + if properties.correlation_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); + } + properties.correlation_data = Some(Vec::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Connect)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for LastWillProperties { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(delay_interval) = self.will_delay_interval { + PropertyType::WillDelayInterval.write(buf)?; + buf.put_u32(delay_interval); + } + if let Some(payload_format_indicator) = self.payload_format_indicator { + PropertyType::PayloadFormatIndicator.write(buf)?; + buf.put_u8(payload_format_indicator); + } + if let Some(message_expiry_interval) = self.message_expiry_interval { + PropertyType::MessageExpiryInterval.write(buf)?; + buf.put_u32(message_expiry_interval); + } + if let Some(content_type) = &self.content_type { + PropertyType::ContentType.write(buf)?; + content_type.write(buf)?; + } + if let Some(response_topic) = &self.response_topic { + PropertyType::ResponseTopic.write(buf)?; + response_topic.write(buf)?; + } + if let Some(correlation_data) = &self.correlation_data { + PropertyType::CorrelationData.write(buf)?; + correlation_data.write(buf)?; + } + if !self.user_properties.is_empty() { + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + } + Ok(()) + } +} + +// impl WireLength for LastWillProperties { +// fn wire_len(&self) -> usize { +// let mut len: usize = 0; + +// if self.will_delay_interval.is_some() { +// len += 5; +// } +// if self.payload_format_indicator.is_some() { +// len += 2; +// } +// if self.message_expiry_interval.is_some() { +// len += 5; +// } +// // +1 for the property type +// len += self.content_type.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); +// len += self.response_topic.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); +// len += self.correlation_data.as_ref().map_or_else(|| 0, |b| b.wire_len() + 1); +// for (key, value) in &self.user_properties { +// len += key.wire_len() + value.wire_len() + 1; +// } + +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs new file mode 100644 index 0000000..f4ba49c --- /dev/null +++ b/mqrstt/src/packets/connect/mod.rs @@ -0,0 +1,496 @@ +mod last_will_properties; +pub use last_will_properties::LastWillProperties; + +mod connect_flags; +pub use connect_flags::ConnectFlags; + +mod connect_properties; +pub use connect_properties::ConnectProperties; + +mod last_will; +pub use last_will::LastWill; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::error::ReadError; + +use super::{ + error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, protocol_version::ProtocolVersion, PacketType, PropertyType, QoS, VariableInteger, WireLength +}; + +/// Connect packet send by the client to the server to initialize a connection. +/// +/// Variable Header +/// - Protocol Name and Version: Identifies the MQTT protocol and version. +/// - Connect Flags: Options like clean start, will flag, will QoS, will retain, password flag, and username flag. +/// - Keep Alive Interval: Maximum time interval between messages. +/// - Properties: Optional settings such as session expiry interval, receive maximum, maximum packet size, and topic alias maximum. +/// +/// Payload +/// - Client Identifier: Unique ID for the client. +/// - Will Message: Optional message sent if the client disconnects unexpectedly. +/// - Username and Password: Optional credentials for authentication. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Connect { + pub protocol_version: ProtocolVersion, + + /// 3.1.2.4 Clean Start Flag + pub clean_start: bool, + /// 3.1.2.5 Will Flag through option + pub last_will: Option, + + /// 3.1.2.8 User Name Flag + pub username: Option>, + /// 3.1.2.9 Password Flag + pub password: Option>, + /// 3.1.2.10 Keep Alive + pub keep_alive: u16, + /// 3.1.2.11 CONNECT Properties + pub connect_properties: ConnectProperties, + + /// 3.1.3.1 Client Identifier (ClientID) + pub client_id: Box, +} + +impl Default for Connect { + fn default() -> Self { + Self { + protocol_version: ProtocolVersion::V5, + clean_start: true, + last_will: None, + username: None, + password: None, + keep_alive: 60, + connect_properties: ConnectProperties::default(), + client_id: "MQRSTT".into(), + } + } +} + +impl PacketRead for Connect { + fn read(_: u8, _: usize, mut buf: Bytes) -> Result { + if String::read(&mut buf)? != "MQTT" { + return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string())); + } + + let protocol_version = ProtocolVersion::read(&mut buf)?; + + let connect_flags = ConnectFlags::read(&mut buf)?; + + let clean_start = connect_flags.clean_start; + let keep_alive = buf.get_u16(); + + let connect_properties = ConnectProperties::read(&mut buf)?; + + let client_id = Box::::read(&mut buf)?; + let mut last_will = None; + if connect_flags.will_flag { + let retain = connect_flags.will_retain; + + last_will = Some(LastWill::read(connect_flags.will_qos, retain, &mut buf)?); + } + + let username = if connect_flags.username { Some(Box::::read(&mut buf)?) } else { None }; + let password = if connect_flags.password { Some(Box::::read(&mut buf)?) } else { None }; + + let connect = Connect { + protocol_version, + clean_start, + last_will, + username, + password, + keep_alive, + connect_properties, + client_id, + }; + + Ok(connect) + } +} + +impl PacketAsyncRead for Connect where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let mut total_read_bytes = 0; + let expected_protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; + let mut protocol = [0u8; 6]; + stream.read_exact(&mut protocol).await?; + + if protocol != expected_protocol { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacketWithInfo(format!("Protocol not MQTT: {:?}", protocol)))); + } + let (protocol_version, _) = ProtocolVersion::async_read(stream).await?; + let (connect_flags, _) = ConnectFlags::async_read(stream).await?; + // Add "MQTT", protocol version and connect flags read bytes + total_read_bytes += 6 + 1 + 1; + + let clean_start = connect_flags.clean_start; + let keep_alive = stream.read_u16().await?; + // Add keep alive read bytes + total_read_bytes += 2; + + let (connect_properties, prop_read_bytes) = ConnectProperties::async_read(stream).await?; + let (client_id, client_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += prop_read_bytes + client_read_bytes; + + let last_will = if connect_flags.will_flag { + let retain = connect_flags.will_retain; + let (last_will, last_will_read_bytes) = LastWill::async_read(connect_flags.will_qos, retain, stream).await?; + total_read_bytes += last_will_read_bytes; + Some(last_will) + } else { + None + }; + + let (username, username_read_bytes) = if connect_flags.username { + let (username, username_read_bytes) = Box::::async_read(stream).await?; + (Some(username), username_read_bytes) + } else { (None, 0) }; + let (password, password_read_bytes) = if connect_flags.password { + let (password, password_read_bytes) = Box::::async_read(stream).await?; + (Some(password), password_read_bytes) + } else { (None, 0) }; + + total_read_bytes += username_read_bytes + password_read_bytes; + + let connect = Connect { + protocol_version, + clean_start, + last_will, + username, + password, + keep_alive, + connect_properties, + client_id, + }; + Ok((connect, total_read_bytes)) + } +} + +impl PacketWrite for Connect { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + "MQTT".write(buf)?; + + self.protocol_version.write(buf)?; + + let mut connect_flags = ConnectFlags { + clean_start: self.clean_start, + ..Default::default() + }; + + if let Some(last_will) = &self.last_will { + connect_flags.will_flag = true; + connect_flags.will_retain = last_will.retain; + connect_flags.will_qos = last_will.qos; + } + connect_flags.username = self.username.is_some(); + connect_flags.password = self.password.is_some(); + + connect_flags.write(buf)?; + + buf.put_u16(self.keep_alive); + + self.connect_properties.write(buf)?; + + self.client_id.write(buf)?; + + if let Some(last_will) = &self.last_will { + last_will.write(buf)?; + } + if let Some(username) = &self.username { + username.write(buf)?; + } + if let Some(password) = &self.password { + password.write(buf)?; + } + Ok(()) + } +} + +impl WireLength for Connect { + fn wire_len(&self) -> usize { + let mut len = "MQTT".wire_len() + 1 + 1 + 2; // protocol version, connect_flags and keep alive + + len += self.connect_properties.wire_len().variable_integer_len(); + len += self.connect_properties.wire_len(); + + if let Some(last_will) = &self.last_will { + len += last_will.wire_len(); + } + if let Some(username) = &self.username { + len += username.wire_len() + } + if let Some(password) = &self.password { + len += password.wire_len() + } + + len += self.client_id.wire_len(); + + len + } +} + +#[cfg(test)] +mod tests { + use crate::packets::{ + mqtt_trait::{MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + QoS, + }; + + use super::{Connect, ConnectFlags, LastWill}; + + #[test] + fn read_connect() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, + 0b1100_1110, // Connect Flags, username, password, will retain=false, will qos=1, last_will, clean_start + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username = 'user' + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // Password = 'pass' + b'a', + b's', + b's', + 0xAB, // extra packets in the stream + 0xCD, + 0xEF, + ]; + + buf.extend_from_slice(packet); + let c = Connect::read(0, 0, buf.into()).unwrap(); + + dbg!(c); + } + + #[test] + fn read_and_write_connect() { + let mut buf = bytes::BytesMut::new(); + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, // variable header + 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // payload. password = 'pass' + b'a', + b's', + b's', + ]; + + buf.extend_from_slice(packet); + let c = Connect::read(0, 0, buf.into()).unwrap(); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(packet.to_vec(), write_buf.to_vec()); + + dbg!(c); + } + + #[tokio::test] + async fn read_async_and_write_connect() { + let packet = &[ + // 0x10, + // 39, // packet type, flags and remaining len + 0x00, + 0x04, + b'M', + b'Q', + b'T', + b'T', + 0x05, // variable header + 0b1100_1110, // variable header. +username, +password, -will retain, will qos=1, +last_will, +clean_session + 0x00, // Keep alive = 10 sec + 0x0a, + 0x00, // Length of Connect properties + 0x00, // client_id length + 0x04, + b't', // client_id + b'e', + b's', + b't', + 0x00, // Will properties length + 0x00, // length topic + 0x02, + b'/', // Will topic = '/a' + b'a', + 0x00, // Will payload length + 0x0B, + b'h', // Will payload = 'hello world' + b'e', + b'l', + b'l', + b'o', + b' ', + b'w', + b'o', + b'r', + b'l', + b'd', + 0x00, // length username + 0x04, + b'u', // username + b's', + b'e', + b'r', + 0x00, // length password + 0x04, + b'p', // password = 'pass' + b'a', + b's', + b's', + ]; + + let (c, read_bytes) = Connect::async_read(0, 0, &mut packet.as_slice()).await.unwrap(); + assert_eq!(packet.len(), read_bytes); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(packet.to_vec(), write_buf.to_vec()); + } + + #[test] + fn parsing_last_will() { + let last_will = &[ + 0x00, // Will properties length + 0x00, // length topic + 0x02, b'/', b'a', // Will topic = '/a' + 0x00, 0x0B, // Will payload length + b'h', b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', // Will payload = 'hello world' + ]; + let mut buf = bytes::Bytes::from_static(last_will); + assert!(LastWill::read(QoS::AtLeastOnce, false, &mut buf).is_ok()); + } + + #[test] + fn read_and_write_connect2() { + let _packet = [ + 0x10, 0x1d, 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, + ]; + + let data = [ + 0x00, 0x04, 0x4d, 0x51, 0x54, 0x54, 0x05, 0x80, 0x00, 0x3c, 0x05, 0x11, 0xff, 0xff, 0xff, 0xff, 0x00, 0x05, 0x39, 0x2e, 0x30, 0x2e, 0x31, 0x00, 0x04, 0x54, 0x65, 0x73, 0x74, + ]; + + let mut buf = bytes::BytesMut::new(); + buf.extend_from_slice(&data); + + let c = Connect::read(0, 0, buf.into()).unwrap(); + + dbg!(c.clone()); + + let mut write_buf = bytes::BytesMut::new(); + c.write(&mut write_buf).unwrap(); + + assert_eq!(data.to_vec(), write_buf.to_vec()); + } + + #[test] + fn parsing_and_writing_last_will() { + let last_will = &[ + 0x00, // Will properties length + 0x00, // length topic + 0x02, b'/', // Will topic = '/a' + b'a', 0x00, // Will payload length + 0x0B, b'h', // Will payload = 'hello world' + b'e', b'l', b'l', b'o', b' ', b'w', b'o', b'r', b'l', b'd', + ]; + let mut buf = bytes::Bytes::from_static(last_will); + + let lw = LastWill::read(QoS::AtLeastOnce, false, &mut buf).unwrap(); + + let mut write_buf = bytes::BytesMut::new(); + lw.write(&mut write_buf).unwrap(); + + assert_eq!(last_will.to_vec(), write_buf.to_vec()); + } + + #[test] + fn connect_flag() { + let byte = 0b1100_1110; + let flags = ConnectFlags::from_u8(byte).unwrap(); + assert_eq!(byte, flags.into_u8().unwrap()); + } +} diff --git a/mqrstt/src/packets/disconnect.rs b/mqrstt/src/packets/disconnect.rs deleted file mode 100644 index e5c70e9..0000000 --- a/mqrstt/src/packets/disconnect.rs +++ /dev/null @@ -1,226 +0,0 @@ -use bytes::BufMut; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::DisconnectReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct Disconnect { - pub reason_code: DisconnectReasonCode, - pub properties: DisconnectProperties, -} - -impl PacketRead for Disconnect { - fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { - let reason_code; - let properties; - if remaining_length == 0 { - reason_code = DisconnectReasonCode::NormalDisconnection; - properties = DisconnectProperties::default(); - } else { - reason_code = DisconnectReasonCode::read(&mut buf)?; - properties = DisconnectProperties::read(&mut buf)?; - } - - Ok(Self { reason_code, properties }) - } -} -impl PacketWrite for Disconnect { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - } - Ok(()) - } -} -impl WireLength for Disconnect { - fn wire_len(&self) -> usize { - if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { - let property_len = self.properties.wire_len(); - // reasoncode, length of property length, property length - 1 + variable_integer_len(property_len) + property_len - } else { - 0 - } - } -} - -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct DisconnectProperties { - pub session_expiry_interval: Option, - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, - pub server_reference: Option>, -} - -impl MqttRead for DisconnectProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - let mut properties = Self::default(); - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::try_from(u8::read(&mut property_data)?)? { - PropertyType::SessionExpiryInterval => { - if properties.session_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.session_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(&mut property_data)?); - } - PropertyType::ServerReference => { - if properties.server_reference.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); - } - properties.server_reference = Some(Box::::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Disconnect)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for DisconnectProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(session_expiry_interval) = self.session_expiry_interval { - PropertyType::SessionExpiryInterval.write(buf)?; - buf.put_u32(session_expiry_interval); - } - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, val) in self.user_properties.iter() { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - val.write(buf)?; - } - if let Some(server_refrence) = &self.server_reference { - PropertyType::ServerReference.write(buf)?; - server_refrence.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for DisconnectProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if self.session_expiry_interval.is_some() { - len += 4 + 1; - } - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - len += self.user_properties.iter().fold(0, |acc, (k, v)| acc + k.wire_len() + v.wire_len() + 1); - if let Some(server_refrence) = &self.server_reference { - len += server_refrence.wire_len() + 1; - } - len - } -} - -#[cfg(test)] -mod tests { - use super::*; - - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_write_and_read_disconnect() { - let mut buf = bytes::BytesMut::new(); - let packet = Disconnect { - properties: DisconnectProperties { - session_expiry_interval: Some(123), - reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], - server_reference: Some(Box::from("Server reference")), - }, - reason_code: DisconnectReasonCode::NormalDisconnection, - }; - - packet.write(&mut buf).unwrap(); - - let read_packet = Disconnect::read(0, buf.len(), buf.into()).unwrap(); - - assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); - assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); - assert_eq!( - read_packet.properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_packet.properties.server_reference, - Some(Box::from("Server reference")) - ); - } -} - - - #[test] - fn test_write_and_read_disconnect_properties() { - let mut buf = bytes::BytesMut::new(); - let properties = DisconnectProperties { - session_expiry_interval: Some(123), - reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], - server_reference: Some(Box::from("Server reference")), - }; - - properties.write(&mut buf).unwrap(); - - let read_properties = DisconnectProperties::read(&mut buf.into()).unwrap(); - - assert_eq!(read_properties.session_expiry_interval, Some(123)); - assert_eq!(read_properties.reason_string, Some(Box::from("Some reason"))); - assert_eq!( - read_properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_properties.server_reference, - Some(Box::from("Server reference")) - ); - } -} \ No newline at end of file diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs new file mode 100644 index 0000000..983ca5f --- /dev/null +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -0,0 +1,178 @@ +mod properties; +pub use properties::DisconnectProperties; + +mod reason_code; +pub use reason_code::DisconnectReasonCode; + +use super::{ + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Disconnect { + pub reason_code: DisconnectReasonCode, + pub properties: DisconnectProperties, +} + +impl PacketAsyncRead for Disconnect +where + S: tokio::io::AsyncReadExt + Unpin, +{ + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + if remaining_length == 0 { + Ok((Self { reason_code: DisconnectReasonCode::NormalDisconnection, properties: DisconnectProperties::default() }, 0)) + } else { + let (reason_code, reason_code_read_bytes) = DisconnectReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = DisconnectProperties::async_read(stream).await?; + + Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes)) + } + } +} +impl PacketRead for Disconnect { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { + let reason_code; + let properties; + if remaining_length == 0 { + reason_code = DisconnectReasonCode::NormalDisconnection; + properties = DisconnectProperties::default(); + } else { + reason_code = DisconnectReasonCode::read(&mut buf)?; + properties = DisconnectProperties::read(&mut buf)?; + } + + Ok(Self { reason_code, properties }) + } +} +impl PacketWrite for Disconnect { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + } + Ok(()) + } +} +impl WireLength for Disconnect { + fn wire_len(&self) -> usize { + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + let property_len = self.properties.wire_len(); + // reasoncode, length of property length, property length + 1 + property_len.variable_integer_len() + property_len + } else { + 0 + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_write_and_async_read_disconnect() { + let mut buf = bytes::BytesMut::new(); + let packet = Disconnect { + properties: DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ], + server_reference: Some(Box::from("Server reference")), + }, + reason_code: DisconnectReasonCode::NormalDisconnection, + }; + + packet.write(&mut buf).unwrap(); + + let mut stream = &*buf; + + let (read_packet, read_bytes) = Disconnect::async_read(0, buf.len(), &mut stream).await.unwrap(); + + assert_eq!(buf.len(), read_bytes); + assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); + assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_packet.properties.user_properties, + vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ] + ); + assert_eq!( + read_packet.properties.server_reference, + Some(Box::from("Server reference")) + ); + } + + #[test] + fn test_write_and_read_disconnect() { + let mut buf = bytes::BytesMut::new(); + let packet = Disconnect { + properties: DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ], + server_reference: Some(Box::from("Server reference")), + }, + reason_code: DisconnectReasonCode::NormalDisconnection, + }; + + packet.write(&mut buf).unwrap(); + + let read_packet = Disconnect::read(0, buf.len(), buf.into()).unwrap(); + + assert_eq!(read_packet.properties.session_expiry_interval, Some(123)); + assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_packet.properties.user_properties, + vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ] + ); + assert_eq!( + read_packet.properties.server_reference, + Some(Box::from("Server reference")) + ); + } + + #[test] + fn test_write_and_read_disconnect_properties() { + let mut buf = bytes::BytesMut::new(); + let properties = DisconnectProperties { + session_expiry_interval: Some(123), + reason_string: Some(Box::from("Some reason")), + user_properties: vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ], + server_reference: Some(Box::from("Server reference")), + }; + + properties.write(&mut buf).unwrap(); + + let read_properties = DisconnectProperties::read(&mut buf.into()).unwrap(); + + assert_eq!(read_properties.session_expiry_interval, Some(123)); + assert_eq!(read_properties.reason_string, Some(Box::from("Some reason"))); + assert_eq!( + read_properties.user_properties, + vec![ + (Box::from("key1"), Box::from("value1")), + (Box::from("key2"), Box::from("value2")), + ] + ); + assert_eq!( + read_properties.server_reference, + Some(Box::from("Server reference")) + ); + } +} + diff --git a/mqrstt/src/packets/disconnect/properties.rs b/mqrstt/src/packets/disconnect/properties.rs new file mode 100644 index 0000000..b750082 --- /dev/null +++ b/mqrstt/src/packets/disconnect/properties.rs @@ -0,0 +1,110 @@ +use bytes::BufMut; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{ MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!(DisconnectProperties, + SessionExpiryInterval, + ReasonString, + UserProperty, + ServerReference +); + +// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +// pub struct DisconnectProperties { +// pub session_expiry_interval: Option, +// pub reason_string: Option>, +// pub user_properties: Vec<(Box, Box)>, +// pub server_reference: Option>, +// } + +impl MqttRead for DisconnectProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + let mut properties = Self::default(); + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + loop { + match PropertyType::try_from(u8::read(&mut property_data)?)? { + PropertyType::SessionExpiryInterval => { + if properties.session_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); + } + properties.session_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(&mut property_data)?); + } + PropertyType::ServerReference => { + if properties.server_reference.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ServerReference)); + } + properties.server_reference = Some(Box::::read(&mut property_data)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Disconnect)), + } + + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for DisconnectProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(session_expiry_interval) = self.session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(session_expiry_interval); + } + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, val) in self.user_properties.iter() { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + val.write(buf)?; + } + if let Some(server_refrence) = &self.server_reference { + PropertyType::ServerReference.write(buf)?; + server_refrence.write(buf)?; + } + Ok(()) + } +} + +// impl WireLength for DisconnectProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if self.session_expiry_interval.is_some() { +// len += 4 + 1; +// } +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// len += self.user_properties.iter().fold(0, |acc, (k, v)| acc + k.wire_len() + v.wire_len() + 1); +// if let Some(server_refrence) = &self.server_reference { +// len += server_refrence.wire_len() + 1; +// } +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/disconnect/reason_code.rs b/mqrstt/src/packets/disconnect/reason_code.rs new file mode 100644 index 0000000..6c993de --- /dev/null +++ b/mqrstt/src/packets/disconnect/reason_code.rs @@ -0,0 +1,148 @@ +crate::packets::macros::reason_code!(DisconnectReasonCode, + NormalDisconnection, + DisconnectWithWillMessage, + UnspecifiedError, + MalformedPacket, + ProtocolError, + ImplementationSpecificError, + NotAuthorized, + ServerBusy, + ServerShuttingDown, + KeepAliveTimeout, + SessionTakenOver, + TopicFilterInvalid, + TopicNameInvalid, + ReceiveMaximumExceeded, + TopicAliasInvalid, + PacketTooLarge, + MessageRateTooHigh, + QuotaExceeded, + AdministrativeAction, + PayloadFormatInvalid, + RetainNotSupported, + QosNotSupported, + UseAnotherServer, + ServerMoved, + SharedSubscriptionsNotSupported, + ConnectionRateExceeded, + MaximumConnectTime, + SubscriptionIdentifiersNotSupported, + WildcardSubscriptionsNotSupported +); + +// #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum DisconnectReasonCode { +// #[default] +// NormalDisconnection, +// DisconnectWithWillMessage, +// UnspecifiedError, +// MalformedPacket, +// ProtocolError, +// ImplementationSpecificError, +// NotAuthorized, +// ServerBusy, +// ServerShuttingDown, +// KeepAliveTimeout, +// SessionTakenOver, +// TopicFilterInvalid, +// TopicNameInvalid, +// ReceiveMaximumExceeded, +// TopicAliasInvalid, +// PacketTooLarge, +// MessageRateTooHigh, +// QuotaExceeded, +// AdministrativeAction, +// PayloadFormatInvalid, +// RetainNotSupported, +// QosNotSupported, +// UseAnotherServer, +// ServerMoved, +// SharedSubscriptionsNotSupported, +// ConnectionRateExceeded, +// MaximumConnectTime, +// SubscriptionIdentifiersNotSupported, +// WildcardSubscriptionsNotSupported, +// } + +// impl MqttRead for DisconnectReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), +// 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), +// 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), +// 0x81 => Ok(DisconnectReasonCode::MalformedPacket), +// 0x82 => Ok(DisconnectReasonCode::ProtocolError), +// 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), +// 0x87 => Ok(DisconnectReasonCode::NotAuthorized), +// 0x89 => Ok(DisconnectReasonCode::ServerBusy), +// 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), +// 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), +// 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), +// 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), +// 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), +// 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), +// 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), +// 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), +// 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), +// 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), +// 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), +// 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), +// 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), +// 0x9B => Ok(DisconnectReasonCode::QosNotSupported), +// 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), +// 0x9D => Ok(DisconnectReasonCode::ServerMoved), +// 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), +// 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), +// 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), +// 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), +// 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for DisconnectReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// DisconnectReasonCode::NormalDisconnection => 0x00, +// DisconnectReasonCode::DisconnectWithWillMessage => 0x04, +// DisconnectReasonCode::UnspecifiedError => 0x80, +// DisconnectReasonCode::MalformedPacket => 0x81, +// DisconnectReasonCode::ProtocolError => 0x82, +// DisconnectReasonCode::ImplementationSpecificError => 0x83, +// DisconnectReasonCode::NotAuthorized => 0x87, +// DisconnectReasonCode::ServerBusy => 0x89, +// DisconnectReasonCode::ServerShuttingDown => 0x8B, + +// DisconnectReasonCode::KeepAliveTimeout => 0x8D, +// DisconnectReasonCode::SessionTakenOver => 0x8E, +// DisconnectReasonCode::TopicFilterInvalid => 0x8F, + +// DisconnectReasonCode::TopicNameInvalid => 0x90, +// DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, +// DisconnectReasonCode::TopicAliasInvalid => 0x94, +// DisconnectReasonCode::PacketTooLarge => 0x95, +// DisconnectReasonCode::MessageRateTooHigh => 0x96, +// DisconnectReasonCode::QuotaExceeded => 0x97, +// DisconnectReasonCode::AdministrativeAction => 0x98, +// DisconnectReasonCode::PayloadFormatInvalid => 0x99, +// DisconnectReasonCode::RetainNotSupported => 0x9A, +// DisconnectReasonCode::QosNotSupported => 0x9B, +// DisconnectReasonCode::UseAnotherServer => 0x9C, +// DisconnectReasonCode::ServerMoved => 0x9D, +// DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, +// DisconnectReasonCode::ConnectionRateExceeded => 0x9F, +// DisconnectReasonCode::MaximumConnectTime => 0xA0, +// DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, +// DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, +// }; + +// buf.put_u8(val); + +// Ok(()) +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/macros/mod.rs b/mqrstt/src/packets/macros/mod.rs index bc734e6..79836ca 100644 --- a/mqrstt/src/packets/macros/mod.rs +++ b/mqrstt/src/packets/macros/mod.rs @@ -4,3 +4,11 @@ mod reason_code_macros; pub(crate) use reason_code_macros::*; pub(crate) use properties_macros::*; + +// macro_rules! assert_length { +// ($len:ident, $read:expr) => { +// if len != (read) { +// return Err(DeserializeError::InvalidLength(std::any::type_name::(), len, read)); +// } +// }; +// } \ No newline at end of file diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index c7a2c05..ef1ed6c 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -2,9 +2,9 @@ macro_rules! define_properties { ($name:ident, $($prop_variant:ident),*) => { $crate::packets::macros::properties_struct!(@ $name { $($prop_variant,)* } -> ()); - impl $crate::packets::mqtt_traits::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { - let (len, length_variable_integer) = $crate::packets::read_async_variable_integer(stream).await?; + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { + let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; if len == 0 { return Ok((Self::default(), length_variable_integer)); } @@ -13,7 +13,7 @@ macro_rules! define_properties { let mut read_property_bytes = 0; loop { - let (prop, read_bytes) = PropertyType::async_read(stream).await?; + let (prop, read_bytes) = crate::packets::PropertyType::async_read(stream).await?; read_property_bytes += read_bytes; match prop { $( @@ -30,12 +30,12 @@ macro_rules! define_properties { } } - impl $crate::packets::mqtt_traits::WireLength for $name { + impl $crate::packets::mqtt_trait::WireLength for $name { fn wire_len(&self) -> usize { let mut len: usize = 0; $( $crate::packets::macros::properties_wire_length!(self, len , PropertyType::$prop_variant); - )*; + )* len } } @@ -89,12 +89,21 @@ macro_rules! properties_struct { pub correlation_data: Option>, )); ); + ( @ $name:ident { ListSubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.3.2.3.8 Subscription Identifier + /// 11 (0x0B), Identifier of the Subscription Identifier. + /// Multiple Subscription Identifiers used in the Publish packet. + pub subscription_identifiers: Vec, + )); + ); ( @ $name:ident { SubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.8 Subscription Identifier /// 11 (0x0B), Identifier of the Subscription Identifier. - pub subscription_identifier: Vec, + pub subscription_identifier: Option, )); ); ( @ $name:ident { SessionExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -137,15 +146,30 @@ macro_rules! properties_struct { pub authentication_data: Option>, )); ); - // ( @ $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( - // // Missing - // ); - // ( @ $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( - // // Missing - // ); - // ( @ $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( - // // Missing - // ); + ( @ $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.1.2.11.7 Request Problem Information + /// 23 (0x17) Byte, Identifier of the Request Problem Information + pub request_problem_information: Option, + )); + ); + ( @ $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.1.3.2.2 Request Problem Information + /// 24 (0x18) Byte, Identifier of the Will Delay Interval. + pub will_delay_interval: Option, + )); + ); + ( @ $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + $($result)* + /// 3.1.2.11.6 Request Response Information + /// 25 (0x19) Byte, Identifier of the Request Response Information + pub request_response_information: Option, + )); + ); ( @ $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( $($result)* @@ -309,9 +333,16 @@ macro_rules! properties_read_matches { }; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => { { - let (prop_body, read_bytes) = $crate::packets::read_async_variable_integer($stream).await?; + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; $read_property_bytes += read_bytes; - $properties.subscription_identifier.push(prop_body); + $properties.subscription_identifier = Some(prop_body as u32); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ListSubscriptionIdentifier) => { + { + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifiers.push(prop_body as u32); } }; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => { @@ -364,16 +395,36 @@ macro_rules! properties_read_matches { $properties.authentication_data = Some(prop_body); } }; - // ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => { - // { - // if $properties.authentication_data.is_some() { - // return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation))); - // } - // let (prop_body, read_bytes) = Vec::::async_read($stream).await?; - // $read_property_bytes += read_bytes; - // $properties.authentication_data = Some(prop_body); - // } - // }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => { + { + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_problem_information = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestProblemInformation) => { + { + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_problem_information = Some(prop_body); + } + }; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WillDelayInterval) => { + { + if $properties.will_delay_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.will_delay_interval = Some(prop_body); + } + }; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => { { if $properties.response_info.is_some() { @@ -534,8 +585,13 @@ macro_rules! properties_wire_length{ } }; ($self:ident, $len:ident, PropertyType::SubscriptionIdentifier) => { - for sub_id in &($self.subscription_identifier) { - $len += 1 + $crate::packets::variable_integer_len(*sub_id); + if let Some(sub_id) = &($self.subscription_identifier) { + $len += 1 + crate::packets::primitive::VariableInteger::variable_integer_len(sub_id); + } + }; + ($self:ident, $len:ident, PropertyType::ListSubscriptionIdentifier) => { + for sub_id in &($self.subscription_identifiers) { + $len += 1 + crate::packets::primitive::VariableInteger::variable_integer_len(sub_id); } }; ($self:ident, $len:ident, PropertyType::SessionExpiryInterval) => { @@ -559,13 +615,27 @@ macro_rules! properties_wire_length{ } }; ($self:ident, $len:ident, PropertyType::AuthenticationData) => { - if $self.authentication_data.is_some() && $self.authentication_method.is_some() { - $len += 1 + $self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); + if let Some(authentication_data) = &($self).authentication_data { + if !authentication_data.is_empty() && $self.authentication_method.is_some() { + $len += 1 + authentication_data.wire_len(); + } + } + }; + ($self:ident, $len:ident, PropertyType::RequestProblemInformation) => { + if $self.request_problem_information.is_some() { + $len += 2; + } + }; + ($self:ident, $len:ident, PropertyType::WillDelayInterval) => { + if $self.will_delay_interval.is_some() { + $len += 5; + } + }; + ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { + if $self.request_response_information.is_some() { + $len += 2; } }; - // ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { - //Will Delay Interval - // ($self:ident, $len:ident, PropertyType::RequestResponseInformation) => { ($self:ident, $len:ident, PropertyType::ResponseInformation) => { if let Some(response_info) = &($self.response_info) { $len += 1 + response_info.wire_len(); diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs index 6afe8be..abb1f37 100644 --- a/mqrstt/src/packets/macros/reason_code_macros.rs +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -6,7 +6,7 @@ macro_rules! reason_code { $($code),* } - impl $crate::packets::mqtt_traits::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + std::marker::Unpin{ + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + std::marker::Unpin{ async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { let input = stream.read_u8().await?; let res = $crate::packets::macros::reason_code_match!(@ $name, input, { @@ -16,11 +16,12 @@ macro_rules! reason_code { } } - impl $crate::packets::mqtt_traits::MqttRead for $name { - fn read(buf: &mut bytes::Bytes) -> Result { + impl $crate::packets::mqtt_trait::MqttRead for $name { + fn read(buf: &mut bytes::Bytes) -> Result { if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + return Err($crate::packets::error::DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } + use bytes::Buf; let res = buf.get_u8(); $crate::packets::macros::reason_code_match!(@ $name, res, { $($code,)* @@ -28,11 +29,12 @@ macro_rules! reason_code { } } - impl $crate::packets::mqtt_traits::MqttWrite for $name { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + impl $crate::packets::mqtt_trait::MqttWrite for $name { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), $crate::packets::error::SerializeError> { let val = $crate::packets::macros::reason_code_match_write!(@ $name, buf, self, { $($code,)* } -> ()); + use bytes::BufMut; buf.put_u8(val); Ok(()) } @@ -45,7 +47,7 @@ macro_rules! reason_code_match { ( @ $name:ident, $input:ident, { } -> ($($result:tt)*) ) => ( match $input { $($result)* - t => Err(DeserializeError::UnknownProperty(t)), + t => Err($crate::packets::error::DeserializeError::UnknownProperty(t)), } ); ( @ $name:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -60,12 +62,42 @@ macro_rules! reason_code_match { 0x00 => Ok($name::NormalDisconnection), )) ); + ( @ $name:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x00 => Ok($name::GrantedQoS0), + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x01 => Ok($name::GrantedQoS1), + )) + ); + ( @ $name:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x02 => Ok($name::GrantedQoS2), + )) + ); ( @ $name:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* 0x04 => Ok($name::DisconnectWithWillMessage), )) ); + ( @ $name:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x10 => Ok($name::NoMatchingSubscribers), + )) + ); + ( @ $name:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x11 => Ok($name::NoSubscriptionExisted), + )) + ); ( @ $name:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* @@ -147,7 +179,7 @@ macro_rules! reason_code_match { ( @ $name:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), + 0x8B => Ok($name::ServerShuttingDown), )) ); ( @ $name:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -159,19 +191,19 @@ macro_rules! reason_code_match { ( @ $name:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), + 0x8D => Ok($name::KeepAliveTimeout), )) ); ( @ $name:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), + 0x8E => Ok($name::SessionTakenOver), )) ); ( @ $name:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), + 0x8F => Ok($name::TopicFilterInvalid), )) ); ( @ $name:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -180,16 +212,28 @@ macro_rules! reason_code_match { 0x90 => Ok($name::TopicNameInvalid), )) ); + ( @ $name:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x91 => Ok($name::PacketIdentifierInUse), + )) + ); + ( @ $name:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( + $($result)* + 0x92 => Ok($name::PacketIdentifierNotFound), + )) + ); ( @ $name:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), + 0x93 => Ok($name::ReceiveMaximumExceeded), )) ); ( @ $name:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match!(@ $name, $input, { $($rest)* } -> ( $($result)* - 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), + 0x94 => Ok($name::TopicAliasInvalid), )) ); ( @ $name:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -299,12 +343,42 @@ macro_rules! reason_code_match_write{ $name::NormalDisconnection => 0x00, )) ); + ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS0 => 0x00, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS1 => 0x01, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::GrantedQoS2 => 0x02, + )) + ); ( @ $name:ident, $buf:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( $($result)* $name::DisconnectWithWillMessage => 0x04, )) ); + ( @ $name:ident, $buf:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::NoMatchingSubscribers => 0x10, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::NoSubscriptionExisted => 0x11, + )) + ); ( @ $name:ident, $buf:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( $($result)* @@ -420,6 +494,19 @@ macro_rules! reason_code_match_write{ $name::TopicNameInvalid => 0x90, )) ); + ( @ $name:ident, $buf:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::PacketIdentifierInUse => 0x91, + )) + ); + ( @ $name:ident, $buf:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + $($result)* + $name::PacketIdentifierNotFound => 0x92, + + )) + ); ( @ $name:ident, $buf:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( $($result)* diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 5b31414..78c86e6 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -1,5 +1,5 @@ pub mod error; -pub mod mqtt_traits; +pub mod mqtt_trait; pub mod reason_codes; mod macros; @@ -18,12 +18,13 @@ mod subscribe; mod unsuback; mod unsubscribe; +mod primitive; +pub use primitive::*; + pub use auth::*; pub use connack::*; pub use connect::*; pub use disconnect::*; -use error::ReadError; -use mqtt_traits::MqttAsyncRead; pub use puback::*; pub use pubcomp::*; pub use publish::*; @@ -35,621 +36,12 @@ pub use unsuback::*; pub use unsubscribe::*; use bytes::{Buf, BufMut, Bytes, BytesMut}; -use core::slice::Iter; use std::fmt::Display; use self::error::{DeserializeError, ReadBytes, SerializeError}; -use self::mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; - -/// Protocol version -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub enum ProtocolVersion { - V5, -} - -impl MqttWrite for ProtocolVersion { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u8(5u8); - Ok(()) - } -} - -impl MqttRead for ProtocolVersion { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientDataForProtocolVersion); - } - - match buf.get_u8() { - 3 => Err(DeserializeError::UnsupportedProtocolVersion), - 4 => Err(DeserializeError::UnsupportedProtocolVersion), - 5 => Ok(ProtocolVersion::V5), - _ => Err(DeserializeError::UnknownProtocolVersion), - } - } -} - -impl MqttAsyncRead for ProtocolVersion where S: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { - match stream.read_u8().await { - Ok(5) => Ok((ProtocolVersion::V5, 1)), - Ok(4) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), - Ok(3) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), - Ok(_) => Err(ReadError::DeserializeError(DeserializeError::UnknownProtocolVersion)), - Err(e) => Err(ReadError::IoError(e)), - } - } -} - -/// Quality of service -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum QoS { - #[default] - AtMostOnce = 0, - AtLeastOnce = 1, - ExactlyOnce = 2, -} -impl QoS { - pub fn from_u8(value: u8) -> Result { - match value { - 0 => Ok(QoS::AtMostOnce), - 1 => Ok(QoS::AtLeastOnce), - 2 => Ok(QoS::ExactlyOnce), - _ => Err(DeserializeError::UnknownQoS(value)), - } - } - pub fn into_u8(self) -> u8 { - match self { - QoS::AtMostOnce => 0, - QoS::AtLeastOnce => 1, - QoS::ExactlyOnce => 2, - } - } -} - -impl MqttRead for QoS { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0 => Ok(QoS::AtMostOnce), - 1 => Ok(QoS::AtLeastOnce), - 2 => Ok(QoS::ExactlyOnce), - q => Err(DeserializeError::UnknownQoS(q)), - } - } -} - -impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - match buf.read_u8().await { - Ok(0) => Ok((QoS::AtMostOnce, 1)), - Ok(1) => Ok((QoS::AtLeastOnce, 1)), - Ok(2) => Ok((QoS::ExactlyOnce, 1)), - Ok(q) => Err(ReadError::DeserializeError(DeserializeError::UnknownQoS(q))), - Err(e) => Err(ReadError::IoError(e)), - } - } -} - -impl MqttWrite for QoS { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - let val = match self { - QoS::AtMostOnce => 0, - QoS::AtLeastOnce => 1, - QoS::ExactlyOnce => 2, - }; - buf.put_u8(val); - Ok(()) - } -} - -impl MqttRead for Box { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let content = Bytes::read(buf)?; - - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s.into()), - Err(e) => Err(DeserializeError::Utf8Error(e)), - } - } -} - -impl MqttAsyncRead for Box where S: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { - let (content, read_bytes) = Vec::async_read(stream).await?; - match String::from_utf8(content) { - Ok(s) => Ok((s.into(), read_bytes)), - Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), - } - } -} - -impl MqttWrite for Box { - #[inline(always)] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - self.as_ref().write(buf) - } -} - -impl WireLength for Box { - #[inline(always)] - fn wire_len(&self) -> usize { - self.as_ref().wire_len() - } -} - -impl MqttWrite for &str { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.len() as u16); - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl WireLength for &str { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for String { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let content = Bytes::read(buf)?; - - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok(s), - Err(e) => Err(DeserializeError::Utf8Error(e)), - } - } -} - -impl MqttAsyncRead for String where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - let (content, read_bytes) = Bytes::async_read(buf).await?; - match String::from_utf8(content.to_vec()) { - Ok(s) => Ok((s, read_bytes)), - Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), - } - } -} - - -impl MqttWrite for String { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - if self.len() > 65535 { - return Err(SerializeError::StringTooLong(self.len())); - } - - buf.put_u16(self.len() as u16); - buf.extend(self.as_bytes()); - Ok(()) - } -} - -impl WireLength for String { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for Bytes { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let len = buf.get_u16() as usize; - - if len > buf.len() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - Ok(buf.split_to(len)) - } -} - -impl MqttAsyncRead for Bytes where S: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { - let size = stream.read_u16().await? as usize; - // let mut data = BytesMut::with_capacity(size); - let mut data = Vec::with_capacity(size); - let read_bytes = stream.read_exact(&mut data).await?; - assert_eq!(size, read_bytes); - Ok((data.into(), 2 + size)) - } -} - -impl MqttWrite for Bytes { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.len() as u16); - buf.extend(self); - - Ok(()) - } -} - -impl WireLength for Bytes { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - -impl MqttRead for Vec { - #[inline] - fn read(buf: &mut Bytes) -> Result { - let len = buf.get_u16() as usize; - - if len > buf.len() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - Ok(buf.split_to(len).into()) - } -} -impl MqttWrite for Vec { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.len() as u16); - buf.extend(self); - - Ok(()) - } -} -impl WireLength for Vec { - #[inline(always)] - fn wire_len(&self) -> usize { - self.len() + 2 - } -} - - - -impl MqttAsyncRead for Vec where S: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { - let size = stream.read_u16().await? as usize; - // let mut data = BytesMut::with_capacity(size); - let mut data = vec![0u8; size]; - let read_bytes = stream.read_exact(&mut data).await?; - assert_eq!(size, read_bytes); - Ok((data, 2 + size)) - } -} - - -impl MqttRead for bool { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(error::DeserializeError::MalformedPacket), - } - } -} - -impl MqttAsyncRead for bool where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - match buf.read_u8().await? { - 0 => Ok((false, 1)), - 1 => Ok((true, 1)), - _ => Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)), - } - } -} - -impl MqttWrite for bool { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - if *self { - buf.put_u8(1); - Ok(()) - } else { - buf.put_u8(0); - Ok(()) - } - } -} - -impl MqttRead for u8 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - Ok(buf.get_u8()) - } -} - -impl MqttAsyncRead for u8 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - Ok((buf.read_u8().await?, 1)) - } -} - -impl MqttRead for u16 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.len() < 2 { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); - } - Ok(buf.get_u16()) - } -} - -impl MqttAsyncRead for u16 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - Ok((buf.read_u16().await?, 2)) - } -} - -impl MqttWrite for u16 { - #[inline] - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u16(*self); - Ok(()) - } -} - -impl MqttRead for u32 { - #[inline] - fn read(buf: &mut Bytes) -> Result { - if buf.len() < 4 { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); - } - Ok(buf.get_u32()) - } -} -impl MqttAsyncRead for u32 where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - Ok((buf.read_u32().await?, 4)) - } -} - -impl MqttWrite for u32 { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u32(*self); - Ok(()) - } -} - -pub fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { - let mut integer = 0; - let mut length = 0; - - for i in 0..4 { - if let Some(byte) = buf.next() { - length += 1; - integer += (*byte as usize & 0x7f) << (7 * i); - - if (*byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } - } else { - return Err(ReadBytes::InsufficientBytes(1)); - } - } - Err(ReadBytes::Err(DeserializeError::MalformedPacket)) -} - -pub fn read_variable_integer(buf: &mut Bytes) -> Result<(usize, usize), DeserializeError> { - let mut integer = 0; - let mut length = 0; - - for i in 0..4 { - if buf.is_empty() { - return Err(DeserializeError::MalformedPacket); - } - length += 1; - let byte = buf.get_u8(); - - integer += (byte as usize & 0x7f) << (7 * i); - - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } - } - Err(DeserializeError::MalformedPacket) -} - -pub async fn read_async_variable_integer(stream: &mut S) -> Result<(usize, usize), ReadError> { - let mut integer = 0; - let mut length = 0; - - for i in 0..4 { - let byte = stream.read_u8().await?; - length += 1; - - integer += (byte as usize & 0x7f) << (7 * i); - - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } - } - Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) -} - -pub fn write_variable_integer(buf: &mut BytesMut, integer: usize) -> Result { - if integer > 268_435_455 { - return Err(SerializeError::VariableIntegerOverflow(integer)); - } - - let mut write = integer; - - for i in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf.put_u8(byte); - if write == 0 { - return Ok(i+1); - } - } - Err(SerializeError::VariableIntegerOverflow(integer)) -} - -pub fn variable_integer_len(integer: usize) -> usize { - if integer >= 2_097_152 { - 4 - } else if integer >= 16_384 { - 3 - } else if integer >= 128 { - 2 - } else { - 1 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PropertyType { - PayloadFormatIndicator = 1, - MessageExpiryInterval = 2, - ContentType = 3, - ResponseTopic = 8, - CorrelationData = 9, - SubscriptionIdentifier = 11, - SessionExpiryInterval = 17, - AssignedClientIdentifier = 18, - ServerKeepAlive = 19, - AuthenticationMethod = 21, - AuthenticationData = 22, - RequestProblemInformation = 23, - WillDelayInterval = 24, - RequestResponseInformation = 25, - ResponseInformation = 26, - ServerReference = 28, - ReasonString = 31, - ReceiveMaximum = 33, - TopicAliasMaximum = 34, - TopicAlias = 35, - MaximumQos = 36, - RetainAvailable = 37, - UserProperty = 38, - MaximumPacketSize = 39, - WildcardSubscriptionAvailable = 40, - SubscriptionIdentifierAvailable = 41, - SharedSubscriptionAvailable = 42, -} - -impl TryFrom for PropertyType { - type Error = DeserializeError; - - fn try_from(value: u8) -> Result { - match value { - 1 => Ok(Self::PayloadFormatIndicator), - 2 => Ok(Self::MessageExpiryInterval), - 3 => Ok(Self::ContentType), - 8 => Ok(Self::ResponseTopic), - 9 => Ok(Self::CorrelationData), - 11 => Ok(Self::SubscriptionIdentifier), - 17 => Ok(Self::SessionExpiryInterval), - 18 => Ok(Self::AssignedClientIdentifier), - 19 => Ok(Self::ServerKeepAlive), - 21 => Ok(Self::AuthenticationMethod), - 22 => Ok(Self::AuthenticationData), - 23 => Ok(Self::RequestProblemInformation), - 24 => Ok(Self::WillDelayInterval), - 25 => Ok(Self::RequestResponseInformation), - 26 => Ok(Self::ResponseInformation), - 28 => Ok(Self::ServerReference), - 31 => Ok(Self::ReasonString), - 33 => Ok(Self::ReceiveMaximum), - 34 => Ok(Self::TopicAliasMaximum), - 35 => Ok(Self::TopicAlias), - 36 => Ok(Self::MaximumQos), - 37 => Ok(Self::RetainAvailable), - 38 => Ok(Self::UserProperty), - 39 => Ok(Self::MaximumPacketSize), - 40 => Ok(Self::WildcardSubscriptionAvailable), - 41 => Ok(Self::SubscriptionIdentifierAvailable), - 42 => Ok(Self::SharedSubscriptionAvailable), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl From<&PropertyType> for u8 { - fn from(value: &PropertyType) -> Self { - match value { - PropertyType::PayloadFormatIndicator => 1, - PropertyType::MessageExpiryInterval => 2, - PropertyType::ContentType => 3, - PropertyType::ResponseTopic => 8, - PropertyType::CorrelationData => 9, - PropertyType::SubscriptionIdentifier => 11, - PropertyType::SessionExpiryInterval => 17, - PropertyType::AssignedClientIdentifier => 18, - PropertyType::ServerKeepAlive => 19, - PropertyType::AuthenticationMethod => 21, - PropertyType::AuthenticationData => 22, - PropertyType::RequestProblemInformation => 23, - PropertyType::WillDelayInterval => 24, - PropertyType::RequestResponseInformation => 25, - PropertyType::ResponseInformation => 26, - PropertyType::ServerReference => 28, - PropertyType::ReasonString => 31, - PropertyType::ReceiveMaximum => 33, - PropertyType::TopicAliasMaximum => 34, - PropertyType::TopicAlias => 35, - PropertyType::MaximumQos => 36, - PropertyType::RetainAvailable => 37, - PropertyType::UserProperty => 38, - PropertyType::MaximumPacketSize => 39, - PropertyType::WildcardSubscriptionAvailable => 40, - PropertyType::SubscriptionIdentifierAvailable => 41, - PropertyType::SharedSubscriptionAvailable => 42, - } - } -} - -impl From for u8 { - fn from(value: PropertyType) -> Self { - value as u8 - } -} - -impl MqttRead for PropertyType { - fn read(buf: &mut Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - buf.get_u8().try_into() - } -} - -impl MqttAsyncRead for PropertyType where T: tokio::io::AsyncReadExt + std::marker::Unpin { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - match buf.read_u8().await { - Ok(t) => Ok((t.try_into()?, 1)), - Err(e) => Err(ReadError::IoError(e)), - } - } -} - -impl MqttWrite for PropertyType { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - buf.put_u8(self.into()); - Ok(()) - } -} +use self::mqtt_trait::{PacketRead, PacketWrite, WireLength}; // ==================== Packets ==================== - #[derive(Debug, Clone, PartialEq, Eq)] pub enum Packet { Connect(Connect), @@ -690,17 +82,17 @@ impl Packet { } } - pub fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + pub(crate) fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { match self { Packet::Connect(p) => { buf.put_u8(0b0001_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::ConnAck(p) => { buf.put_u8(0b0010_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Publish(p) => { @@ -715,32 +107,32 @@ impl Packet { first_byte |= 0b0001; } buf.put_u8(first_byte); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubAck(p) => { buf.put_u8(0b0100_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubRec(p) => { buf.put_u8(0b0101_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubRel(p) => { buf.put_u8(0b0110_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::PubComp(p) => { buf.put_u8(0b0111_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Subscribe(p) => { buf.put_u8(0b1000_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::SubAck(_) => { @@ -748,7 +140,7 @@ impl Packet { } Packet::Unsubscribe(p) => { buf.put_u8(0b1010_0010); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::UnsubAck(_) => { @@ -765,19 +157,19 @@ impl Packet { } Packet::Disconnect(p) => { buf.put_u8(0b1110_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } Packet::Auth(p) => { buf.put_u8(0b1111_0000); - write_variable_integer(buf, p.wire_len())?; + p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } } Ok(()) } - pub fn read(header: FixedHeader, buf: Bytes) -> Result { + pub(crate) fn read(header: FixedHeader, buf: Bytes) -> Result { let packet = match header.packet_type { PacketType::Connect => Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?), PacketType::ConnAck => Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?), @@ -798,7 +190,7 @@ impl Packet { Ok(packet) } - pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + pub(crate) fn read_from_buffer(buffer: &mut BytesMut) -> Result> { let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; if header.remaining_length + header_length > buffer.len() { return Err(ReadBytes::InsufficientBytes(header.remaining_length + header_length - buffer.len())); @@ -841,42 +233,6 @@ impl Display for Packet { } } -// 2.1.1 Fixed Header -// ``` -// 7 3 0 -// +--------------------------+--------------------------+ -// byte 1 | MQTT Control Packet Type | Flags for Packet type | -// +--------------------------+--------------------------+ -// | Remaining Length | -// +-----------------------------------------------------+ -// -// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 -// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] -pub struct FixedHeader { - pub packet_type: PacketType, - pub flags: u8, - pub remaining_length: usize, -} - -impl FixedHeader { - pub fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { - if header.len() < 2 { - return Err(ReadBytes::InsufficientBytes(2 - header.len())); - } - - let mut header_length = 1; - let first_byte = header.next().unwrap(); - - let (packet_type, flags) = PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; - - let (remaining_length, length) = read_fixed_header_rem_len(header)?; - header_length += length; - - Ok((Self { packet_type, flags, remaining_length }, header_length)) - } -} - /// 2.1.2 MQTT Control Packet type #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum PacketType { @@ -935,7 +291,9 @@ mod tests { use crate::packets::publish::{Publish, PublishProperties}; use crate::packets::pubrel::{PubRel, PubRelProperties}; - use crate::packets::reason_codes::{ConnAckReasonCode, DisconnectReasonCode, PubRelReasonCode}; + use crate::packets::connack::ConnAckReasonCode; + use crate::packets::disconnect::DisconnectReasonCode; + use crate::packets::pubrel::PubRelReasonCode; use crate::packets::Packet; #[test] @@ -1060,7 +418,7 @@ mod tests { topic_alias: None, response_topic: None, correlation_data: Some(b"1212".to_vec()), - subscription_identifier: vec![1], + subscription_identifiers: vec![1], user_properties: vec![], content_type: None, }, diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs new file mode 100644 index 0000000..2bad43c --- /dev/null +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -0,0 +1,106 @@ +mod primitive_impl; +pub use primitive_impl::*; + +use std::{future::Future}; + +use bytes::{Bytes, BytesMut}; + +// mod sealed { +// /// Sealed trait to prevent downstream users from implementing the +// /// [`crate::packets::mqtt_trait::PacketRead`], [`crate::packets::mqtt_trait::PacketWrite`], +// /// [`crate::packets::mqtt_trait::PacketAsyncRead`] [`crate::packets::mqtt_trait::PacketAsyncWrite`], +// /// [`crate::packets::mqtt_trait::MqttRead`], [`crate::packets::mqtt_trait::MqttWrite`] +// /// and [`crate::packets::mqtt_trait::WireLength`] traits. +// pub trait Sealed {} +// impl Sealed for crate::packets::ConnAck {} + +// } + +// pub(crate) trait PacketRead: Sized + sealed::Sealed { +// fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; +// } + +// pub(crate) trait PacketAsyncRead: Sized + sealed::Sealed where S: tokio::io::AsyncReadExt + Unpin { +// fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; +// } + +// pub(crate) trait PacketWrite: Sized + sealed::Sealed { +// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +// } + +// pub(crate) trait WireLength: sealed::Sealed { +// fn wire_len(&self) -> usize; +// } + +// pub(crate) trait MqttRead: Sized + sealed::Sealed { +// fn read(buf: &mut Bytes) -> Result; +// } +// pub trait MqttAsyncRead: Sized + sealed::Sealed +// { +// /// Reads `Self` from the provided stream. +// /// Returns the deserialized instance and the number of bytes read from the stream. +// fn async_read(stream: &mut S) -> impl Future>; +// } + + +// pub trait MqttWrite: Sized + sealed::Sealed { +// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +// } + +// impl<'a, T> MqttWrite for &'a T +// where +// T: MqttWrite, +// &'a T: sealed::Sealed, +// { +// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError> { +// ::write(self, buf) +// } +// } + +// pub trait PacketValidation: Sized + sealed::Sealed { +// fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; +// } + +pub(crate) trait PacketRead: Sized { + fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; +} + +pub(crate) trait PacketAsyncRead: Sized where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; +} + +pub(crate) trait PacketWrite: Sized { + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +} + +pub(crate) trait WireLength { + fn wire_len(&self) -> usize; +} + +pub(crate) trait MqttRead: Sized { + fn read(buf: &mut Bytes) -> Result; +} +pub(crate) trait MqttAsyncRead: Sized +{ + /// Reads `Self` from the provided stream. + /// Returns the deserialized instance and the number of bytes read from the stream. + fn async_read(stream: &mut S) -> impl Future>; +} + + +pub trait MqttWrite: Sized { + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; +} + +impl<'a, T> MqttWrite for &'a T +where + T: MqttWrite, +{ + fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError> { + ::write(self, buf) + } +} + +pub trait PacketValidation: Sized { + fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; +} diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs new file mode 100644 index 0000000..76b9a82 --- /dev/null +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -0,0 +1,269 @@ +use bytes::{BufMut, Buf, Bytes, BytesMut}; + +use crate::packets::mqtt_trait::{MqttRead, MqttAsyncRead, MqttWrite, WireLength}; +use crate::packets::error::{DeserializeError, ReadError, SerializeError}; + + +impl MqttRead for Box { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let content = Bytes::read(buf)?; + + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s.into()), + Err(e) => Err(DeserializeError::Utf8Error(e)), + } + } +} + +impl MqttAsyncRead for Box where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Vec::async_read(stream).await?; + match String::from_utf8(content) { + Ok(s) => Ok((s.into(), read_bytes)), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + +impl MqttWrite for Box { + #[inline(always)] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + self.as_ref().write(buf) + } +} + +impl WireLength for Box { + #[inline(always)] + fn wire_len(&self) -> usize { + self.as_ref().wire_len() + } +} + +impl MqttWrite for &str { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self.as_bytes()); + Ok(()) + } +} + +impl WireLength for &str { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for String { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let content = Bytes::read(buf)?; + + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok(s), + Err(e) => Err(DeserializeError::Utf8Error(e)), + } + } +} + +impl MqttAsyncRead for String where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + let (content, read_bytes) = Bytes::async_read(buf).await?; + match String::from_utf8(content.to_vec()) { + Ok(s) => Ok((s, read_bytes)), + Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))), + } + } +} + +impl MqttWrite for String { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + if self.len() > 65535 { + return Err(SerializeError::StringTooLong(self.len())); + } + + buf.put_u16(self.len() as u16); + buf.extend(self.as_bytes()); + Ok(()) + } +} + +impl WireLength for String { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for Bytes { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let len = buf.get_u16() as usize; + + if len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + Ok(buf.split_to(len)) + } +} +impl MqttAsyncRead for Bytes where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = Vec::with_capacity(size); + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data.into(), 2 + size)) + } +} +impl MqttWrite for Bytes { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self); + + Ok(()) + } +} +impl WireLength for Bytes { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} + +impl MqttRead for Vec { + #[inline] + fn read(buf: &mut Bytes) -> Result { + let len = buf.get_u16() as usize; + + if len > buf.len() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + Ok(buf.split_to(len).into()) + } +} +impl MqttWrite for Vec { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.len() as u16); + buf.extend(self); + + Ok(()) + } +} +impl WireLength for Vec { + #[inline(always)] + fn wire_len(&self) -> usize { + self.len() + 2 + } +} +impl MqttAsyncRead for Vec where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + let size = stream.read_u16().await? as usize; + // let mut data = BytesMut::with_capacity(size); + let mut data = vec![0u8; size]; + let read_bytes = stream.read_exact(&mut data).await?; + assert_eq!(size, read_bytes); + Ok((data, 2 + size)) + } +} + + +impl MqttRead for bool { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + match buf.get_u8() { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(crate::packets::error::DeserializeError::MalformedPacket), + } + } +} +impl MqttAsyncRead for bool where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + match buf.read_u8().await? { + 0 => Ok((false, 1)), + 1 => Ok((true, 1)), + _ => Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)), + } + } +} +impl MqttWrite for bool { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + if *self { + buf.put_u8(1); + Ok(()) + } else { + buf.put_u8(0); + Ok(()) + } + } +} + +impl MqttRead for u8 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + Ok(buf.get_u8()) + } +} +impl MqttAsyncRead for u8 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u8().await?, 1)) + } +} + +impl MqttRead for u16 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } + Ok(buf.get_u16()) + } +} +impl MqttAsyncRead for u16 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u16().await?, 2)) + } +} +impl MqttWrite for u16 { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(*self); + Ok(()) + } +} + +impl MqttRead for u32 { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.len() < 4 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 4)); + } + Ok(buf.get_u32()) + } +} +impl MqttAsyncRead for u32 where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + Ok((buf.read_u32().await?, 4)) + } +} +impl MqttWrite for u32 { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u32(*self); + Ok(()) + } +} diff --git a/mqrstt/src/packets/mqtt_traits.rs b/mqrstt/src/packets/mqtt_traits.rs deleted file mode 100644 index 72ebb9f..0000000 --- a/mqrstt/src/packets/mqtt_traits.rs +++ /dev/null @@ -1,50 +0,0 @@ -use std::{future::Future, process::Output}; - -use bytes::{Bytes, BytesMut}; - -use super::error::{DeserializeError, ReadError, SerializeError}; - -pub trait PacketRead: Sized { - fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; -} - -pub trait PacketAsyncRead: Sized where S: tokio::io::AsyncReadExt + Unpin { - fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; -} - -pub trait PacketWrite: Sized { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; -} - -pub trait WireLength { - fn wire_len(&self) -> usize; -} - -pub trait MqttRead: Sized { - fn read(buf: &mut Bytes) -> Result; -} -pub trait MqttAsyncRead: Sized -// where S: tokio::io::AsyncReadExt + Unpin -{ - /// Reads `Self` from the provided stream. - /// Returns the deserialized instance and the number of bytes read from the stream. - fn async_read(stream: &mut S) -> impl Future>; -} - - -pub trait MqttWrite: Sized { - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; -} - -impl MqttWrite for &T -where - T: MqttWrite, -{ - fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - ::write(self, buf) - } -} - -pub trait PacketValidation: Sized { - fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; -} diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs new file mode 100644 index 0000000..709c565 --- /dev/null +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -0,0 +1,40 @@ +use core::slice::Iter; +use crate::packets::{error::{DeserializeError, ReadBytes}, PacketType}; + +use super::read_fixed_header_rem_len; + +// 2.1.1 Fixed Header +// ``` +// 7 3 0 +// +--------------------------+--------------------------+ +// byte 1 | MQTT Control Packet Type | Flags for Packet type | +// +--------------------------+--------------------------+ +// | Remaining Length | +// +-----------------------------------------------------+ +// +// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 +// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub(crate) struct FixedHeader { + pub packet_type: PacketType, + pub flags: u8, + pub remaining_length: usize, +} + +impl FixedHeader { + pub fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { + if header.len() < 2 { + return Err(ReadBytes::InsufficientBytes(2 - header.len())); + } + + let mut header_length = 1; + let first_byte = header.next().unwrap(); + + let (packet_type, flags) = PacketType::from_first_byte(*first_byte).map_err(ReadBytes::Err)?; + + let (remaining_length, length) = read_fixed_header_rem_len(header)?; + header_length += length; + + Ok((Self { packet_type, flags, remaining_length }, header_length)) + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/primitive/mod.rs b/mqrstt/src/packets/primitive/mod.rs new file mode 100644 index 0000000..1b098a5 --- /dev/null +++ b/mqrstt/src/packets/primitive/mod.rs @@ -0,0 +1,13 @@ +mod fixed_header; +pub(crate) use fixed_header::FixedHeader; + +pub mod protocol_version; + +mod property_type; +pub(crate) use property_type::PropertyType; + +mod variable_integer; +pub(crate) use variable_integer::*; + +mod qos; +pub use qos::QoS; diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs new file mode 100644 index 0000000..1091293 --- /dev/null +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -0,0 +1,140 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PropertyType { + PayloadFormatIndicator = 1, + MessageExpiryInterval = 2, + ContentType = 3, + ResponseTopic = 8, + CorrelationData = 9, + SubscriptionIdentifier = 11, + /// Alternative to SubscriptionIdentifier, dummy property to encode multiple subscription identifiers + ListSubscriptionIdentifier, + SessionExpiryInterval = 17, + AssignedClientIdentifier = 18, + ServerKeepAlive = 19, + AuthenticationMethod = 21, + AuthenticationData = 22, + RequestProblemInformation = 23, + WillDelayInterval = 24, + RequestResponseInformation = 25, + ResponseInformation = 26, + ServerReference = 28, + ReasonString = 31, + ReceiveMaximum = 33, + TopicAliasMaximum = 34, + TopicAlias = 35, + MaximumQos = 36, + RetainAvailable = 37, + UserProperty = 38, + MaximumPacketSize = 39, + WildcardSubscriptionAvailable = 40, + SubscriptionIdentifierAvailable = 41, + SharedSubscriptionAvailable = 42, +} + +impl TryFrom for PropertyType { + type Error = DeserializeError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(Self::PayloadFormatIndicator), + 2 => Ok(Self::MessageExpiryInterval), + 3 => Ok(Self::ContentType), + 8 => Ok(Self::ResponseTopic), + 9 => Ok(Self::CorrelationData), + 11 => Ok(Self::SubscriptionIdentifier), + 17 => Ok(Self::SessionExpiryInterval), + 18 => Ok(Self::AssignedClientIdentifier), + 19 => Ok(Self::ServerKeepAlive), + 21 => Ok(Self::AuthenticationMethod), + 22 => Ok(Self::AuthenticationData), + 23 => Ok(Self::RequestProblemInformation), + 24 => Ok(Self::WillDelayInterval), + 25 => Ok(Self::RequestResponseInformation), + 26 => Ok(Self::ResponseInformation), + 28 => Ok(Self::ServerReference), + 31 => Ok(Self::ReasonString), + 33 => Ok(Self::ReceiveMaximum), + 34 => Ok(Self::TopicAliasMaximum), + 35 => Ok(Self::TopicAlias), + 36 => Ok(Self::MaximumQos), + 37 => Ok(Self::RetainAvailable), + 38 => Ok(Self::UserProperty), + 39 => Ok(Self::MaximumPacketSize), + 40 => Ok(Self::WildcardSubscriptionAvailable), + 41 => Ok(Self::SubscriptionIdentifierAvailable), + 42 => Ok(Self::SharedSubscriptionAvailable), + t => Err(DeserializeError::UnknownProperty(t)), + } + } +} + +impl From<&PropertyType> for u8 { + fn from(value: &PropertyType) -> Self { + match value { + PropertyType::PayloadFormatIndicator => 1, + PropertyType::MessageExpiryInterval => 2, + PropertyType::ContentType => 3, + PropertyType::ResponseTopic => 8, + PropertyType::CorrelationData => 9, + PropertyType::SubscriptionIdentifier => 11, + PropertyType::ListSubscriptionIdentifier => 11, + PropertyType::SessionExpiryInterval => 17, + PropertyType::AssignedClientIdentifier => 18, + PropertyType::ServerKeepAlive => 19, + PropertyType::AuthenticationMethod => 21, + PropertyType::AuthenticationData => 22, + PropertyType::RequestProblemInformation => 23, + PropertyType::WillDelayInterval => 24, + PropertyType::RequestResponseInformation => 25, + PropertyType::ResponseInformation => 26, + PropertyType::ServerReference => 28, + PropertyType::ReasonString => 31, + PropertyType::ReceiveMaximum => 33, + PropertyType::TopicAliasMaximum => 34, + PropertyType::TopicAlias => 35, + PropertyType::MaximumQos => 36, + PropertyType::RetainAvailable => 37, + PropertyType::UserProperty => 38, + PropertyType::MaximumPacketSize => 39, + PropertyType::WildcardSubscriptionAvailable => 40, + PropertyType::SubscriptionIdentifierAvailable => 41, + PropertyType::SharedSubscriptionAvailable => 42, + } + } +} + +impl From for u8 { + fn from(value: PropertyType) -> Self { + value as u8 + } +} + +impl MqttRead for PropertyType { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + buf.get_u8().try_into() + } +} + +impl MqttAsyncRead for PropertyType where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + match buf.read_u8().await { + Ok(t) => Ok((t.try_into()?, 1)), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + +impl MqttWrite for PropertyType { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u8(self.into()); + Ok(()) + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/primitive/protocol_version.rs b/mqrstt/src/packets/primitive/protocol_version.rs new file mode 100644 index 0000000..7830cd2 --- /dev/null +++ b/mqrstt/src/packets/primitive/protocol_version.rs @@ -0,0 +1,43 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; + +/// Protocol version +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] +pub enum ProtocolVersion { + V5, +} + +impl MqttWrite for ProtocolVersion { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u8(5u8); + Ok(()) + } +} + +impl MqttRead for ProtocolVersion { + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientDataForProtocolVersion); + } + + match buf.get_u8() { + 3 => Err(DeserializeError::UnsupportedProtocolVersion), + 4 => Err(DeserializeError::UnsupportedProtocolVersion), + 5 => Ok(ProtocolVersion::V5), + _ => Err(DeserializeError::UnknownProtocolVersion), + } + } +} + +impl MqttAsyncRead for ProtocolVersion where S: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + match stream.read_u8().await { + Ok(5) => Ok((ProtocolVersion::V5, 1)), + Ok(4) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(3) => Err(ReadError::DeserializeError(DeserializeError::UnsupportedProtocolVersion)), + Ok(_) => Err(ReadError::DeserializeError(DeserializeError::UnknownProtocolVersion)), + Err(e) => Err(ReadError::IoError(e)), + } + } +} diff --git a/mqrstt/src/packets/primitive/qos.rs b/mqrstt/src/packets/primitive/qos.rs new file mode 100644 index 0000000..f019587 --- /dev/null +++ b/mqrstt/src/packets/primitive/qos.rs @@ -0,0 +1,71 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; + +use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; + + +/// Quality of service +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum QoS { + #[default] + AtMostOnce = 0, + AtLeastOnce = 1, + ExactlyOnce = 2, +} +impl QoS { + pub fn from_u8(value: u8) -> Result { + match value { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + _ => Err(DeserializeError::UnknownQoS(value)), + } + } + pub fn into_u8(self) -> u8 { + match self { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce => 1, + QoS::ExactlyOnce => 2, + } + } +} + +impl MqttRead for QoS { + #[inline] + fn read(buf: &mut Bytes) -> Result { + if buf.is_empty() { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); + } + + match buf.get_u8() { + 0 => Ok(QoS::AtMostOnce), + 1 => Ok(QoS::AtLeastOnce), + 2 => Ok(QoS::ExactlyOnce), + q => Err(DeserializeError::UnknownQoS(q)), + } + } +} + +impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker::Unpin { + async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { + match buf.read_u8().await { + Ok(0) => Ok((QoS::AtMostOnce, 1)), + Ok(1) => Ok((QoS::AtLeastOnce, 1)), + Ok(2) => Ok((QoS::ExactlyOnce, 1)), + Ok(q) => Err(ReadError::DeserializeError(DeserializeError::UnknownQoS(q))), + Err(e) => Err(ReadError::IoError(e)), + } + } +} + +impl MqttWrite for QoS { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + let val = match self { + QoS::AtMostOnce => 0, + QoS::AtLeastOnce => 1, + QoS::ExactlyOnce => 2, + }; + buf.put_u8(val); + Ok(()) + } +} \ No newline at end of file diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs new file mode 100644 index 0000000..c27bf8c --- /dev/null +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -0,0 +1,182 @@ +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use core::slice::Iter; +use std::future::Future; +use crate::packets::error::{DeserializeError, ReadBytes, ReadError, SerializeError}; + +pub(crate) fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if let Some(byte) = buf.next() { + length += 1; + integer += (*byte as usize & 0x7f) << (7 * i); + + if (*byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } else { + return Err(ReadBytes::InsufficientBytes(1)); + } + } + Err(ReadBytes::Err(DeserializeError::MalformedPacket)) +} + +pub(crate) trait VariableInteger: Sized { + fn variable_integer_len(&self) -> usize; + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result; + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError>; + fn read_async_variable_integer(stream: &mut S) -> impl Future>; + +} + + + +impl VariableInteger for usize { + + fn variable_integer_len(&self) -> usize { + if *self >= 2_097_152 { + 4 + } else if *self >= 16_384 { + 3 + } else if *self >= 128 { + 2 + } else { + 1 + } + } + + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { + if *self > 268_435_455 { + return Err(SerializeError::VariableIntegerOverflow(*self as usize)); + } + + let mut write = *self; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf.put_u8(byte); + if write == 0 { + return Ok(i+1); + } + } + Err(SerializeError::VariableIntegerOverflow(*self as usize)) + } + + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if buf.is_empty() { + return Err(DeserializeError::MalformedPacket); + } + length += 1; + let byte = buf.get_u8(); + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(DeserializeError::MalformedPacket) + } + + fn read_async_variable_integer(stream: &mut S) -> impl Future> { + async move { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) + } + } +} + +impl VariableInteger for u32 { + fn variable_integer_len(&self) -> usize { + if *self >= 2_097_152 { + 4 + } else if *self >= 16_384 { + 3 + } else if *self >= 128 { + 2 + } else { + 1 + } + } + + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { + if *self > 268_435_455 { + return Err(SerializeError::VariableIntegerOverflow(*self as usize)); + } + + let mut write = *self; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf.put_u8(byte); + if write == 0 { + return Ok(i+1); + } + } + Err(SerializeError::VariableIntegerOverflow(*self as usize)) + } + + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + if buf.is_empty() { + return Err(DeserializeError::MalformedPacket); + } + length += 1; + let byte = buf.get_u8(); + + integer += (byte as u32 & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(DeserializeError::MalformedPacket) + } + + fn read_async_variable_integer(stream: &mut S) -> impl Future> { + async move { + let mut integer = 0; + let mut length = 0; + + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + + integer += (byte as u32 & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) + } + } +} diff --git a/mqrstt/src/packets/puback.rs b/mqrstt/src/packets/puback/mod.rs similarity index 75% rename from mqrstt/src/packets/puback.rs rename to mqrstt/src/packets/puback/mod.rs index 69b5c63..4420837 100644 --- a/mqrstt/src/packets/puback.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -1,11 +1,12 @@ +mod reason_code; +pub use reason_code::PubAckReasonCode; + use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::PubAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, VariableInteger, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -15,6 +16,31 @@ pub struct PubAck { pub properties: PubAckProperties, } +impl PacketAsyncRead for PubAck where S: tokio::io::AsyncReadExt + Unpin { + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let packet_identifier = stream.read_u16().await?; + if remaining_length == 2 { + Ok((Self { + packet_identifier, + reason_code: PubAckReasonCode::Success, + properties: PubAckProperties::default(), + }, 2)) + } else if remaining_length < 4 { + return Err(crate::packets::error::ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), remaining_length, 4))); + } else { + let (reason_code, reason_code_read_bytes) = PubAckReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubAckProperties::async_read(stream).await?; + + Ok((Self { + packet_identifier, + reason_code, + properties, + }, 2 + reason_code_read_bytes + properties_read_bytes)) + } + + } +} + impl PacketRead for PubAck { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { // reason code and properties are optional if reasoncode is success and properties empty. @@ -69,26 +95,31 @@ impl WireLength for PubAck { } else { let prop_len = self.properties.wire_len(); // pkid, reason code, length of the length of properties and lenght of properties - 3 + variable_integer_len(prop_len) + prop_len + 3 + prop_len.variable_integer_len() + prop_len } } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubAckProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} +crate::packets::macros::define_properties!(PubAckProperties, + ReasonString, + UserProperty +); -impl PubAckProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} +// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] +// pub struct PubAckProperties { +// pub reason_string: Option>, +// pub user_properties: Vec<(Box, Box)>, +// } + +// impl PubAckProperties { +// pub fn is_empty(&self) -> bool { +// self.reason_string.is_none() && self.user_properties.is_empty() +// } +// } impl MqttRead for PubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; + let (len, _) = VariableInteger::read_variable_integer(buf)?; if len == 0 { return Ok(Self::default()); @@ -122,7 +153,7 @@ impl MqttWrite for PubAckProperties { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let len = self.wire_len(); - write_variable_integer(buf, len)?; + len.write_variable_integer(buf)?; if let Some(reason_string) = &self.reason_string { PropertyType::ReasonString.write(buf)?; @@ -138,27 +169,26 @@ impl MqttWrite for PubAckProperties { } } -impl WireLength for PubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } +// impl WireLength for PubAckProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } - len - } -} +// len +// } +// } #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, puback::{PubAck, PubAckProperties}, - reason_codes::PubAckReasonCode, - write_variable_integer, PropertyType, + PropertyType, PubAckReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; @@ -221,7 +251,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -248,7 +278,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubAckProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/puback/reason_code.rs b/mqrstt/src/packets/puback/reason_code.rs new file mode 100644 index 0000000..c38f7da --- /dev/null +++ b/mqrstt/src/packets/puback/reason_code.rs @@ -0,0 +1,67 @@ + +crate::packets::macros::reason_code!( + PubAckReasonCode, + Success, + NoMatchingSubscribers, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicNameInvalid, + PacketIdentifierInUse, + QuotaExceeded, + PayloadFormatInvalid +); + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum PubAckReasonCode { +// Success, +// NoMatchingSubscribers, +// UnspecifiedError, +// ImplementationSpecificError, +// NotAuthorized, +// TopicNameInvalid, +// PacketIdentifierInUse, +// QuotaExceeded, +// PayloadFormatInvalid, +// } + +// impl MqttRead for PubAckReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(PubAckReasonCode::Success), +// 0x10 => Ok(PubAckReasonCode::NoMatchingSubscribers), +// 0x80 => Ok(PubAckReasonCode::UnspecifiedError), +// 0x83 => Ok(PubAckReasonCode::ImplementationSpecificError), +// 0x87 => Ok(PubAckReasonCode::NotAuthorized), +// 0x90 => Ok(PubAckReasonCode::TopicNameInvalid), +// 0x91 => Ok(PubAckReasonCode::PacketIdentifierInUse), +// 0x97 => Ok(PubAckReasonCode::QuotaExceeded), +// 0x99 => Ok(PubAckReasonCode::PayloadFormatInvalid), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for PubAckReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// PubAckReasonCode::Success => 0x00, +// PubAckReasonCode::NoMatchingSubscribers => 0x10, +// PubAckReasonCode::UnspecifiedError => 0x80, +// PubAckReasonCode::ImplementationSpecificError => 0x83, +// PubAckReasonCode::NotAuthorized => 0x87, +// PubAckReasonCode::TopicNameInvalid => 0x90, +// PubAckReasonCode::PacketIdentifierInUse => 0x91, +// PubAckReasonCode::QuotaExceeded => 0x97, +// PubAckReasonCode::PayloadFormatInvalid => 0x99, +// }; + +// buf.put_u8(val); + +// Ok(()) +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubcomp.rs b/mqrstt/src/packets/pubcomp/mod.rs similarity index 68% rename from mqrstt/src/packets/pubcomp.rs rename to mqrstt/src/packets/pubcomp/mod.rs index 119699c..286d05d 100644 --- a/mqrstt/src/packets/pubcomp.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -1,11 +1,18 @@ +mod reason_code; +pub use reason_code::PubCompReasonCode; + + +mod properties; +pub use properties::PubCompProperties; + + + use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::PubCompReasonCode, - write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -52,6 +59,36 @@ impl PacketRead for PubComp { } } +impl PacketAsyncRead for PubComp where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + if remaining_length == 2 { + return Ok((Self { + packet_identifier, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), + }, 2)); + } + // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes + else if remaining_length < 4 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4).into()); + } + + let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; + + assert_eq!(id_read_bytes + reason_code_read_bytes + properties_read_bytes, remaining_length); + + Ok((Self { + packet_identifier, + reason_code, + properties, + }, id_read_bytes + reason_code_read_bytes + properties_read_bytes)) + } + } +} + impl PacketWrite for PubComp { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -80,91 +117,11 @@ impl WireLength for PubComp { } } -#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] -pub struct PubCompProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubCompProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubCompProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties = PubCompProperties::default(); - - loop { - match PropertyType::try_from(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubComp)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for PubCompProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -impl WireLength for PubCompProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - - len - } -} #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - pubcomp::{PubComp, PubCompProperties}, - reason_codes::PubCompReasonCode, - write_variable_integer, PropertyType, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubcomp::{PubComp, PubCompProperties}, PropertyType, PubCompReasonCode, VariableInteger }; use bytes::{BufMut, Bytes, BytesMut}; @@ -227,7 +184,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -252,7 +209,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubCompProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs new file mode 100644 index 0000000..3320039 --- /dev/null +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -0,0 +1,89 @@ + +use bytes::BufMut; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + + +crate::packets::macros::define_properties!(PubCompProperties, ReasonString, UserProperty); + +// #[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] +// pub struct PubCompProperties { +// pub reason_string: Option>, +// pub user_properties: Vec<(Box, Box)>, +// } + +// impl PubCompProperties { +// pub fn is_empty(&self) -> bool { +// self.reason_string.is_none() && self.user_properties.is_empty() +// } +// } + +impl MqttRead for PubCompProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubCompProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubComp)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for PubCompProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} + +// impl WireLength for PubCompProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } + +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubcomp/reason_code.rs b/mqrstt/src/packets/pubcomp/reason_code.rs new file mode 100644 index 0000000..a5c531a --- /dev/null +++ b/mqrstt/src/packets/pubcomp/reason_code.rs @@ -0,0 +1,5 @@ +crate::packets::macros::reason_code!( + PubCompReasonCode, + Success, + PacketIdentifierNotFound +); diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs new file mode 100644 index 0000000..729532f --- /dev/null +++ b/mqrstt/src/packets/publish/mod.rs @@ -0,0 +1,202 @@ +mod properties; +pub use properties::PublishProperties; + + +use bytes::{BufMut, Bytes}; + +use crate::error::PacketValidationError; +use crate::util::constants::MAXIMUM_TOPIC_SIZE; + +use super::mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; +use super::VariableInteger; +use super::{ + error::{DeserializeError, SerializeError}, + PacketType, PropertyType, QoS, +}; + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Publish { + /// 3.3.1.1 dup + pub dup: bool, + /// 3.3.1.2 QoS + pub qos: QoS, + /// 3.3.1.3 retain + pub retain: bool, + + /// 3.3.2.1 Topic Name + /// The Topic Name identifies the information channel to which Payload data is published. + pub topic: Box, + + /// 3.3.2.2 Packet Identifier + /// The Packet Identifier field is only present in PUBLISH packets where the QoS level is 1 or 2. Section 2.2.1 provides more information about Packet Identifiers. + pub packet_identifier: Option, + + /// 3.3.2.3 PUBLISH Properties + pub publish_properties: PublishProperties, + + /// 3.3.3 PUBLISH Payload + pub payload: Bytes, +} + +impl Publish { + pub fn new>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: Bytes) -> Self { + Self { + dup: false, + qos, + retain, + topic: topic.as_ref().into(), + packet_identifier, + publish_properties, + payload, + } + } + + pub fn payload_to_vec(&self) -> Vec { + self.payload.to_vec() + } +} + +impl PacketRead for Publish { + fn read(flags: u8, _: usize, mut buf: bytes::Bytes) -> Result { + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let topic = Box::::read(&mut buf)?; + let mut packet_identifier = None; + if qos != QoS::AtMostOnce { + packet_identifier = Some(u16::read(&mut buf)?); + } + + let publish_properties = PublishProperties::read(&mut buf)?; + + Ok(Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload: buf, + }) + } +} + +impl PacketWrite for Publish { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.topic.write(buf)?; + + if let Some(pkid) = self.packet_identifier { + buf.put_u16(pkid); + } + + self.publish_properties.write(buf)?; + + buf.extend(&self.payload); + + Ok(()) + } +} + +impl WireLength for Publish { + fn wire_len(&self) -> usize { + let mut len = self.topic.wire_len(); + if self.packet_identifier.is_some() { + len += 2; + } + + let properties_len = self.publish_properties.wire_len(); + + len += properties_len.variable_integer_len(); + len += properties_len; + len += self.payload.len(); + len + } +} + +impl PacketValidation for Publish { + fn validate(&self, max_packet_size: usize) -> Result<(), PacketValidationError> { + use PacketValidationError::*; + if self.wire_len() > max_packet_size { + Err(MaxPacketSize(self.wire_len())) + } else if self.topic.len() > MAXIMUM_TOPIC_SIZE { + Err(TopicSize(self.topic.len())) + } else { + Ok(()) + } + } +} + + + +#[cfg(test)] +mod tests { + use bytes::{BufMut, BytesMut}; + + use crate::packets::{ + mqtt_trait::{PacketRead, PacketWrite, WireLength}, VariableInteger, + }; + + use super::Publish; + + #[test] + fn test_read_write_properties() { + let first_byte = 0b0011_0100; + + let mut properties = [1, 0, 2].to_vec(); + properties.extend(4_294_967_295u32.to_be_bytes()); + properties.push(35); + properties.extend(3456u16.to_be_bytes()); + properties.push(8); + let resp_topic = "hellogoodbye"; + properties.extend((resp_topic.len() as u16).to_be_bytes()); + properties.extend(resp_topic.as_bytes()); + + let mut buf_one = BytesMut::from( + &[ + 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' + ][..], + ); + buf_one.put_u16(10); + properties.len().write_variable_integer(&mut buf_one).unwrap(); + buf_one.extend(properties); + buf_one.extend( + [ + 0x01, // Payload + 0x02, 0xDE, 0xAD, 0xBE, + ] + .to_vec(), + ); + + let rem_len = buf_one.len(); + + let buf = buf_one.clone(); + + let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); + + let mut result_buf = BytesMut::with_capacity(1000); + p.write(&mut result_buf).unwrap(); + + assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + } + + #[test] + fn test_read_write() { + let first_byte = 0b0011_0000; + let buf_one = &[ + 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' + 0x00, 0x01, 0x02, // payload + 0xDE, 0xAD, 0xBE, + ]; + let rem_len = buf_one.len(); + + let buf = BytesMut::from(&buf_one[..]); + + let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); + + let mut result_buf = BytesMut::new(); + p.write(&mut result_buf).unwrap(); + + assert_eq!(buf_one.to_vec(), result_buf.to_vec()) + } +} diff --git a/mqrstt/src/packets/publish.rs b/mqrstt/src/packets/publish/properties.rs similarity index 54% rename from mqrstt/src/packets/publish.rs rename to mqrstt/src/packets/publish/properties.rs index 38059e4..ecdab62 100644 --- a/mqrstt/src/packets/publish.rs +++ b/mqrstt/src/packets/publish/properties.rs @@ -1,386 +1,201 @@ -use bytes::{BufMut, Bytes}; - -use crate::error::PacketValidationError; -use crate::util::constants::MAXIMUM_TOPIC_SIZE; - -use super::mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; -use super::{ - error::{DeserializeError, SerializeError}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, -}; - -#[derive(Debug, Default, Clone, PartialEq, Eq)] -pub struct Publish { - /// 3.3.1.1 dup - pub dup: bool, - /// 3.3.1.2 QoS - pub qos: QoS, - /// 3.3.1.3 retain - pub retain: bool, - - /// 3.3.2.1 Topic Name - /// The Topic Name identifies the information channel to which Payload data is published. - pub topic: Box, - - /// 3.3.2.2 Packet Identifier - /// The Packet Identifier field is only present in PUBLISH packets where the QoS level is 1 or 2. Section 2.2.1 provides more information about Packet Identifiers. - pub packet_identifier: Option, - - /// 3.3.2.3 PUBLISH Properties - pub publish_properties: PublishProperties, - - /// 3.3.3 PUBLISH Payload - pub payload: Bytes, -} - -impl Publish { - pub fn new>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: Bytes) -> Self { - Self { - dup: false, - qos, - retain, - topic: topic.as_ref().into(), - packet_identifier, - publish_properties, - payload, - } - } - - pub fn payload_to_vec(&self) -> Vec { - self.payload.to_vec() - } -} - -impl PacketRead for Publish { - fn read(flags: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let dup = flags & 0b1000 != 0; - let qos = QoS::from_u8((flags & 0b110) >> 1)?; - let retain = flags & 0b1 != 0; - - let topic = Box::::read(&mut buf)?; - let mut packet_identifier = None; - if qos != QoS::AtMostOnce { - packet_identifier = Some(u16::read(&mut buf)?); - } - - let publish_properties = PublishProperties::read(&mut buf)?; - - Ok(Self { - dup, - qos, - retain, - topic, - packet_identifier, - publish_properties, - payload: buf, - }) - } -} - -impl PacketWrite for Publish { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - self.topic.write(buf)?; - - if let Some(pkid) = self.packet_identifier { - buf.put_u16(pkid); - } - - self.publish_properties.write(buf)?; - - buf.extend(&self.payload); - - Ok(()) - } -} - -impl WireLength for Publish { - fn wire_len(&self) -> usize { - let mut len = self.topic.wire_len(); - if self.packet_identifier.is_some() { - len += 2; - } - - let properties_len = self.publish_properties.wire_len(); - - len += variable_integer_len(properties_len); - len += properties_len; - len += self.payload.len(); - len - } -} - -impl PacketValidation for Publish { - fn validate(&self, max_packet_size: usize) -> Result<(), PacketValidationError> { - use PacketValidationError::*; - if self.wire_len() > max_packet_size { - Err(MaxPacketSize(self.wire_len())) - } else if self.topic.len() > MAXIMUM_TOPIC_SIZE { - Err(TopicSize(self.topic.len())) - } else { - Ok(()) - } - } -} - -super::macros::define_properties!(PublishProperties, - PayloadFormatIndicator, - MessageExpiryInterval, - ContentType, - ResponseTopic, - CorrelationData, - SubscriptionIdentifier, - TopicAlias, - UserProperty -); - - -// #[derive(Debug, Clone, PartialEq, Eq, Default)] -// pub struct PublishProperties { -// /// 3.3.2.3.2 Payload Format Indicator -// /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. -// pub payload_format_indicator: Option, - -// /// 3.3.2.3.3 Message Expiry Interval -// /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. -// pub message_expiry_interval: Option, - -// /// 3.3.2.3.9 Content Type -// /// 3 (0x03) Identifier of the Content Type -// pub content_type: Option>, - -// /// 3.3.2.3.5 Response Topic -// /// 8 (0x08) Byte, Identifier of the Response Topic. -// pub response_topic: Option>, - -// /// 3.3.2.3.6 Correlation Data -// /// 9 (0x09) Byte, Identifier of the Correlation Data. -// pub correlation_data: Option, - -// /// 3.3.2.3.8 Subscription Identifier -// /// 11 (0x0B), Identifier of the Subscription Identifier. -// pub subscription_identifier: Vec, - -// /// 3.3.2.3.4 Topic Alias -// /// 35 (0x23) Byte, Identifier of the Topic Alias. -// pub topic_alias: Option, - -// /// 3.3.2.3.7 User Property -// /// 38 (0x26) Byte, Identifier of the User Property. -// pub user_properties: Vec<(Box, Box)>, - -// } - -impl MqttRead for PublishProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf).map_err(DeserializeError::from)?; - - if len == 0 { - return Ok(Self::default()); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut property_data = buf.split_to(len); - - let mut properties = Self::default(); - - loop { - match PropertyType::try_from(u8::read(&mut property_data)?)? { - PropertyType::PayloadFormatIndicator => { - if properties.payload_format_indicator.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); - } - properties.payload_format_indicator = Some(u8::read(&mut property_data)?); - } - PropertyType::MessageExpiryInterval => { - if properties.message_expiry_interval.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); - } - properties.message_expiry_interval = Some(u32::read(&mut property_data)?); - } - PropertyType::TopicAlias => { - if properties.topic_alias.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAlias)); - } - properties.topic_alias = Some(u16::read(&mut property_data)?); - } - PropertyType::ResponseTopic => { - if properties.response_topic.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); - } - properties.response_topic = Some(Box::::read(&mut property_data)?); - } - PropertyType::CorrelationData => { - if properties.correlation_data.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); - } - properties.correlation_data = Some(Vec::::read(&mut property_data)?); - } - PropertyType::SubscriptionIdentifier => { - properties.subscription_identifier.push(read_variable_integer(&mut property_data)?.0); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), - PropertyType::ContentType => { - if properties.content_type.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); - } - properties.content_type = Some(Box::::read(&mut property_data)?); - } - t => return Err(DeserializeError::UnexpectedProperty(t, PacketType::Publish)), - } - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for PublishProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(payload_format_indicator) = self.payload_format_indicator { - buf.put_u8(PropertyType::PayloadFormatIndicator.into()); - buf.put_u8(payload_format_indicator); - } - if let Some(message_expiry_interval) = self.message_expiry_interval { - buf.put_u8(PropertyType::MessageExpiryInterval.into()); - buf.put_u32(message_expiry_interval); - } - if let Some(topic_alias) = self.topic_alias { - buf.put_u8(PropertyType::TopicAlias.into()); - buf.put_u16(topic_alias); - } - if let Some(response_topic) = &self.response_topic { - buf.put_u8(PropertyType::ResponseTopic.into()); - response_topic.as_ref().write(buf)?; - } - if let Some(correlation_data) = &self.correlation_data { - buf.put_u8(PropertyType::CorrelationData.into()); - correlation_data.write(buf)?; - } - for sub_id in &self.subscription_identifier { - buf.put_u8(PropertyType::SubscriptionIdentifier.into()); - write_variable_integer(buf, *sub_id)?; - } - for (key, val) in &self.user_properties { - buf.put_u8(PropertyType::UserProperty.into()); - key.write(buf)?; - val.write(buf)?; - } - if let Some(content_type) = &self.content_type { - buf.put_u8(PropertyType::ContentType.into()); - content_type.write(buf)?; - } - - Ok(()) - } -} - -// impl WireLength for PublishProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; - -// if self.payload_format_indicator.is_some() { -// len += 2; -// } -// if self.message_expiry_interval.is_some() { -// len += 5; -// } -// if self.topic_alias.is_some() { -// len += 3; -// } -// if let Some(response_topic) = &self.response_topic { -// len += 1 + response_topic.wire_len(); -// } -// if let Some(correlation_data) = &self.correlation_data { -// len += 1 + correlation_data.wire_len(); -// } -// for sub_id in &self.subscription_identifier { -// len += 1 + variable_integer_len(*sub_id); -// } -// for (key, val) in &self.user_properties { -// len += 1 + key.wire_len() + val.wire_len(); -// } -// if let Some(content_type) = &self.content_type { -// len += 1 + content_type.wire_len(); -// } - -// len -// } -// } - -#[cfg(test)] -mod tests { - use bytes::{BufMut, BytesMut}; - - use crate::packets::{ - mqtt_traits::{PacketRead, PacketWrite, WireLength}, - write_variable_integer, - }; - - use super::Publish; - - #[test] - fn test_read_write_properties() { - let first_byte = 0b0011_0100; - - let mut properties = [1, 0, 2].to_vec(); - properties.extend(4_294_967_295u32.to_be_bytes()); - properties.push(35); - properties.extend(3456u16.to_be_bytes()); - properties.push(8); - let resp_topic = "hellogoodbye"; - properties.extend((resp_topic.len() as u16).to_be_bytes()); - properties.extend(resp_topic.as_bytes()); - - let mut buf_one = BytesMut::from( - &[ - 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' - ][..], - ); - buf_one.put_u16(10); - write_variable_integer(&mut buf_one, properties.len()).unwrap(); - buf_one.extend(properties); - buf_one.extend( - [ - 0x01, // Payload - 0x02, 0xDE, 0xAD, 0xBE, - ] - .to_vec(), - ); - - let rem_len = buf_one.len(); - - let buf = buf_one.clone(); - - let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); - - let mut result_buf = BytesMut::with_capacity(1000); - p.write(&mut result_buf).unwrap(); - - assert_eq!(buf_one.to_vec(), result_buf.to_vec()) - } - - #[test] - fn test_read_write() { - let first_byte = 0b0011_0000; - let buf_one = &[ - 0x00, 0x03, b'a', b'/', b'b', // variable header. topic name = 'a/b' - 0x00, 0x01, 0x02, // payload - 0xDE, 0xAD, 0xBE, - ]; - let rem_len = buf_one.len(); - - let buf = BytesMut::from(&buf_one[..]); - - let p = Publish::read(first_byte & 0b0000_1111, rem_len, buf.into()).unwrap(); - - let mut result_buf = BytesMut::new(); - p.write(&mut result_buf).unwrap(); - - assert_eq!(buf_one.to_vec(), result_buf.to_vec()) - } -} +use bytes::BufMut; + +use crate::error::PacketValidationError; +use crate::packets::VariableInteger; +use crate::util::constants::MAXIMUM_TOPIC_SIZE; + +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; +use crate::packets::{ + error::{DeserializeError, SerializeError}, + PacketType, PropertyType, QoS, +}; + +crate::packets::macros::define_properties!(PublishProperties, + PayloadFormatIndicator, + MessageExpiryInterval, + ContentType, + ResponseTopic, + CorrelationData, + ListSubscriptionIdentifier, + TopicAlias, + UserProperty +); + + +// #[derive(Debug, Clone, PartialEq, Eq, Default)] +// pub struct PublishProperties { +// /// 3.3.2.3.2 Payload Format Indicator +// /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. +// pub payload_format_indicator: Option, + +// /// 3.3.2.3.3 Message Expiry Interval +// /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. +// pub message_expiry_interval: Option, + +// /// 3.3.2.3.9 Content Type +// /// 3 (0x03) Identifier of the Content Type +// pub content_type: Option>, + +// /// 3.3.2.3.5 Response Topic +// /// 8 (0x08) Byte, Identifier of the Response Topic. +// pub response_topic: Option>, + +// /// 3.3.2.3.6 Correlation Data +// /// 9 (0x09) Byte, Identifier of the Correlation Data. +// pub correlation_data: Option, + +// /// 3.3.2.3.8 Subscription Identifier +// /// 11 (0x0B), Identifier of the Subscription Identifier. +// pub subscription_identifier: Vec, + +// /// 3.3.2.3.4 Topic Alias +// /// 35 (0x23) Byte, Identifier of the Topic Alias. +// pub topic_alias: Option, + +// /// 3.3.2.3.7 User Property +// /// 38 (0x26) Byte, Identifier of the User Property. +// pub user_properties: Vec<(Box, Box)>, + +// } + +impl MqttRead for PublishProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; + + if len == 0 { + return Ok(Self::default()); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut property_data = buf.split_to(len); + + let mut properties = Self::default(); + + loop { + match PropertyType::try_from(u8::read(&mut property_data)?)? { + PropertyType::PayloadFormatIndicator => { + if properties.payload_format_indicator.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator)); + } + properties.payload_format_indicator = Some(u8::read(&mut property_data)?); + } + PropertyType::MessageExpiryInterval => { + if properties.message_expiry_interval.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval)); + } + properties.message_expiry_interval = Some(u32::read(&mut property_data)?); + } + PropertyType::TopicAlias => { + if properties.topic_alias.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::TopicAlias)); + } + properties.topic_alias = Some(u16::read(&mut property_data)?); + } + PropertyType::ResponseTopic => { + if properties.response_topic.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic)); + } + properties.response_topic = Some(Box::::read(&mut property_data)?); + } + PropertyType::CorrelationData => { + if properties.correlation_data.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::CorrelationData)); + } + properties.correlation_data = Some(Vec::::read(&mut property_data)?); + } + PropertyType::SubscriptionIdentifier => { + properties.subscription_identifiers.push(VariableInteger::read_variable_integer(&mut property_data)?.0); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(&mut property_data)?, Box::::read(&mut property_data)?)), + PropertyType::ContentType => { + if properties.content_type.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ContentType)); + } + properties.content_type = Some(Box::::read(&mut property_data)?); + } + t => return Err(DeserializeError::UnexpectedProperty(t, PacketType::Publish)), + } + if property_data.is_empty() { + break; + } + } + + Ok(properties) + } +} + +impl MqttWrite for PublishProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + self.wire_len().write_variable_integer(buf)?; + + if let Some(payload_format_indicator) = self.payload_format_indicator { + buf.put_u8(PropertyType::PayloadFormatIndicator.into()); + buf.put_u8(payload_format_indicator); + } + if let Some(message_expiry_interval) = self.message_expiry_interval { + buf.put_u8(PropertyType::MessageExpiryInterval.into()); + buf.put_u32(message_expiry_interval); + } + if let Some(topic_alias) = self.topic_alias { + buf.put_u8(PropertyType::TopicAlias.into()); + buf.put_u16(topic_alias); + } + if let Some(response_topic) = &self.response_topic { + buf.put_u8(PropertyType::ResponseTopic.into()); + response_topic.as_ref().write(buf)?; + } + if let Some(correlation_data) = &self.correlation_data { + buf.put_u8(PropertyType::CorrelationData.into()); + correlation_data.write(buf)?; + } + for sub_id in &self.subscription_identifiers { + buf.put_u8(PropertyType::SubscriptionIdentifier.into()); + sub_id.write_variable_integer(buf)?; + } + for (key, val) in &self.user_properties { + buf.put_u8(PropertyType::UserProperty.into()); + key.write(buf)?; + val.write(buf)?; + } + if let Some(content_type) = &self.content_type { + buf.put_u8(PropertyType::ContentType.into()); + content_type.write(buf)?; + } + + Ok(()) + } +} + +// impl WireLength for PublishProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; + +// if self.payload_format_indicator.is_some() { +// len += 2; +// } +// if self.message_expiry_interval.is_some() { +// len += 5; +// } +// if self.topic_alias.is_some() { +// len += 3; +// } +// if let Some(response_topic) = &self.response_topic { +// len += 1 + response_topic.wire_len(); +// } +// if let Some(correlation_data) = &self.correlation_data { +// len += 1 + correlation_data.wire_len(); +// } +// for sub_id in &self.subscription_identifier { +// len += 1 + *sub_id.variable_integer_len(); +// } +// for (key, val) in &self.user_properties { +// len += 1 + key.wire_len() + val.wire_len(); +// } +// if let Some(content_type) = &self.content_type { +// len += 1 + content_type.wire_len(); +// } + +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrec.rs b/mqrstt/src/packets/pubrec/mod.rs similarity index 68% rename from mqrstt/src/packets/pubrec.rs rename to mqrstt/src/packets/pubrec/mod.rs index 065e421..48ddbe4 100644 --- a/mqrstt/src/packets/pubrec.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -1,11 +1,15 @@ + +mod properties; +pub use properties::PubRecProperties; + +mod reason_code; +pub use reason_code::PubRecReasonCode; + use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::PubRecReasonCode, - write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, PacketType, PropertyType, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -79,91 +83,10 @@ impl WireLength for PubRec { } } -#[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -pub struct PubRecProperties { - pub reason_string: Option>, - pub user_properties: Vec<(Box, Box)>, -} - -impl PubRecProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubRecProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties = PubRecProperties::default(); - - loop { - match PropertyType::try_from(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRec)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for PubRecProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -impl WireLength for PubRecProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len() + 1; - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - - len - } -} - #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - pubrec::{PubRec, PubRecProperties}, - reason_codes::PubRecReasonCode, - write_variable_integer, PropertyType, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubrec::{PubRec, PubRecProperties}, PropertyType, PubRecReasonCode, VariableInteger }; use bytes::{BufMut, Bytes, BytesMut}; @@ -226,7 +149,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -253,7 +176,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubRecProperties::read(&mut buf.clone().into()).unwrap(); diff --git a/mqrstt/src/packets/pubrec/properties.rs b/mqrstt/src/packets/pubrec/properties.rs new file mode 100644 index 0000000..66b8fcf --- /dev/null +++ b/mqrstt/src/packets/pubrec/properties.rs @@ -0,0 +1,84 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!(PubRecProperties, ReasonString, UserProperty); + +// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] +// pub struct PubRecProperties { +// pub reason_string: Option>, +// pub user_properties: Vec<(Box, Box)>, +// } + +// impl PubRecProperties { +// pub fn is_empty(&self) -> bool { +// self.reason_string.is_none() && self.user_properties.is_empty() +// } +// } + +impl MqttRead for PubRecProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubRecProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRec)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for PubRecProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} + +// impl WireLength for PubRecProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } + +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrec/reason_code.rs b/mqrstt/src/packets/pubrec/reason_code.rs new file mode 100644 index 0000000..327a79c --- /dev/null +++ b/mqrstt/src/packets/pubrec/reason_code.rs @@ -0,0 +1,67 @@ + + +crate::packets::macros::reason_code!(PubRecReasonCode, + Success, + NoMatchingSubscribers, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicNameInvalid, + PacketIdentifierInUse, + QuotaExceeded, + PayloadFormatInvalid +); + + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum PubRecReasonCode { +// Success, +// NoMatchingSubscribers, +// UnspecifiedError, +// ImplementationSpecificError, +// NotAuthorized, +// TopicNameInvalid, +// PacketIdentifierInUse, +// QuotaExceeded, +// PayloadFormatInvalid, +// } + +// impl MqttRead for PubRecReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(PubRecReasonCode::Success), +// 0x10 => Ok(PubRecReasonCode::NoMatchingSubscribers), +// 0x80 => Ok(PubRecReasonCode::UnspecifiedError), +// 0x83 => Ok(PubRecReasonCode::ImplementationSpecificError), +// 0x87 => Ok(PubRecReasonCode::NotAuthorized), +// 0x90 => Ok(PubRecReasonCode::TopicNameInvalid), +// 0x91 => Ok(PubRecReasonCode::PacketIdentifierInUse), +// 0x97 => Ok(PubRecReasonCode::QuotaExceeded), +// 0x99 => Ok(PubRecReasonCode::PayloadFormatInvalid), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for PubRecReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// PubRecReasonCode::Success => 0x00, +// PubRecReasonCode::NoMatchingSubscribers => 0x10, +// PubRecReasonCode::UnspecifiedError => 0x80, +// PubRecReasonCode::ImplementationSpecificError => 0x83, +// PubRecReasonCode::NotAuthorized => 0x87, +// PubRecReasonCode::TopicNameInvalid => 0x90, +// PubRecReasonCode::PacketIdentifierInUse => 0x91, +// PubRecReasonCode::QuotaExceeded => 0x97, +// PubRecReasonCode::PayloadFormatInvalid => 0x99, +// }; + +// buf.put_u8(val); +// Ok(()) +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrel.rs b/mqrstt/src/packets/pubrel/mod.rs similarity index 70% rename from mqrstt/src/packets/pubrel.rs rename to mqrstt/src/packets/pubrel/mod.rs index caadaf2..c320d6a 100644 --- a/mqrstt/src/packets/pubrel.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -1,7 +1,16 @@ + +mod reason_code; +pub use reason_code::PubRelReasonCode; + +mod properties; +pub use properties::PubRelProperties; + use bytes::BufMut; use super::{ - error::{DeserializeError, ReadError}, mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, read_async_variable_integer, read_variable_integer, reason_codes::PubRelReasonCode, write_variable_integer, PacketType, PropertyType + error::{DeserializeError, ReadError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -108,134 +117,10 @@ impl WireLength for PubRel { } } -// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -// pub struct PubRelProperties { -// pub reason_string: Option>, -// pub user_properties: Vec<(Box, Box)>, -// } - -super::macros::define_properties!(PubRelProperties, ReasonString, UserProperty); - -impl PubRelProperties { - pub fn is_empty(&self) -> bool { - self.reason_string.is_none() && self.user_properties.is_empty() - } -} - -impl MqttRead for PubRelProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - if len == 0 { - return Ok(Self::default()); - } - if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties = PubRelProperties::default(); - - loop { - match PropertyType::try_from(u8::read(buf)?)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); - } - properties.reason_string = Some(Box::::read(buf)?); - } - PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRel)), - } - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -// impl MqttAsyncRead for PubRelProperties where S: tokio::io::AsyncReadExt + Unpin { -// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { -// let (len, length_variable_integer) = read_async_variable_integer(stream).await?; -// if len == 0 { -// return Ok((Self::default(), length_variable_integer)); -// } - -// let mut properties = PubRelProperties::default(); - -// let mut read_property_bytes = 0; -// loop { -// let (prop, read_bytes) = PropertyType::async_read(stream).await?; -// read_property_bytes += read_bytes; -// match prop { -// PropertyType::ReasonString => { -// if properties.reason_string.is_some() { -// return Err(super::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); -// } -// let (prop_body, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; -// properties.reason_string = Some(prop_body); -// } -// PropertyType::UserProperty => { -// let (prop_body_key, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; -// let (prop_body_value, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; - -// properties.user_properties.push((prop_body_key, prop_body_value)) -// }, -// e => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), -// } -// if read_property_bytes == len { -// break; -// } -// } - -// Ok((properties, length_variable_integer + read_property_bytes)) -// } -// } - -impl MqttWrite for PubRelProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let len = self.wire_len(); - - write_variable_integer(buf, len)?; - - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)? - } - - Ok(()) - } -} - -// impl WireLength for PubRelProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } - -// len -// } -// } - #[cfg(test)] mod tests { use crate::packets::{ - mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - pubrel::{PubRel, PubRelProperties}, - reason_codes::PubRelReasonCode, - write_variable_integer, PropertyType, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, pubrel::{PubRel, PubRelProperties}, PropertyType, PubRelReasonCode, VariableInteger }; use bytes::{Buf, BufMut, Bytes, BytesMut}; use tokio::{io::ReadBuf, stream}; @@ -278,7 +163,7 @@ mod tests { let len = prop.wire_len(); // determine length of variable integer - let len_of_wire_len = write_variable_integer(&mut buf, len).unwrap(); + let len_of_wire_len = len.write_variable_integer(&mut buf).unwrap(); // clear buffer before writing actual properties buf.clear(); prop.write(&mut buf).unwrap(); @@ -404,7 +289,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -435,7 +320,7 @@ mod tests { "Another thingy".write(&mut properties).unwrap(); "The thingy".write(&mut properties).unwrap(); - write_variable_integer(&mut buf, properties.len()).unwrap(); + properties.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties); @@ -464,7 +349,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let properties = PubRelProperties::read(&mut buf.clone().into()).unwrap(); @@ -487,7 +372,7 @@ mod tests { "The thingy".write(&mut properties_data).unwrap(); let mut buf = BytesMut::new(); - write_variable_integer(&mut buf, properties_data.len()).unwrap(); + properties_data.len().write_variable_integer(&mut buf).unwrap(); buf.extend(properties_data); let (properties, read_bytes) = PubRelProperties::async_read(&mut &*buf).await.unwrap(); diff --git a/mqrstt/src/packets/pubrel/properties.rs b/mqrstt/src/packets/pubrel/properties.rs new file mode 100644 index 0000000..6ccad2a --- /dev/null +++ b/mqrstt/src/packets/pubrel/properties.rs @@ -0,0 +1,122 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger +}; + +crate::packets::macros::define_properties!(PubRelProperties, + ReasonString, + UserProperty +); + +impl PubRelProperties { + pub fn is_empty(&self) -> bool { + self.reason_string.is_none() && self.user_properties.is_empty() + } +} + +impl MqttRead for PubRelProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + if len == 0 { + return Ok(Self::default()); + } + if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties = PubRelProperties::default(); + + loop { + match PropertyType::try_from(u8::read(buf)?)? { + PropertyType::ReasonString => { + if properties.reason_string.is_some() { + return Err(DeserializeError::DuplicateProperty(PropertyType::ReasonString)); + } + properties.reason_string = Some(Box::::read(buf)?); + } + PropertyType::UserProperty => properties.user_properties.push((Box::::read(buf)?, Box::::read(buf)?)), + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::PubRel)), + } + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +// impl MqttAsyncRead for PubRelProperties where S: tokio::io::AsyncReadExt + Unpin { +// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { +// let (len, length_variable_integer) = read_async_variable_integer(stream).await?; +// if len == 0 { +// return Ok((Self::default(), length_variable_integer)); +// } + +// let mut properties = PubRelProperties::default(); + +// let mut read_property_bytes = 0; +// loop { +// let (prop, read_bytes) = PropertyType::async_read(stream).await?; +// read_property_bytes += read_bytes; +// match prop { +// PropertyType::ReasonString => { +// if properties.reason_string.is_some() { +// return Err(super::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); +// } +// let (prop_body, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; +// properties.reason_string = Some(prop_body); +// } +// PropertyType::UserProperty => { +// let (prop_body_key, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; +// let (prop_body_value, read_bytes) = Box::::async_read(stream).await?; +// read_property_bytes += read_bytes; + +// properties.user_properties.push((prop_body_key, prop_body_value)) +// }, +// e => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), +// } +// if read_property_bytes == len { +// break; +// } +// } + +// Ok((properties, length_variable_integer + read_property_bytes)) +// } +// } + +impl MqttWrite for PubRelProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + let len = self.wire_len(); + + len.write_variable_integer(buf)?; + + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)? + } + + Ok(()) + } +} + +// impl WireLength for PubRelProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += reason_string.wire_len() + 1; +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } + +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrel/reason_code.rs b/mqrstt/src/packets/pubrel/reason_code.rs new file mode 100644 index 0000000..275410b --- /dev/null +++ b/mqrstt/src/packets/pubrel/reason_code.rs @@ -0,0 +1,49 @@ +crate::packets::macros::reason_code!( + PubRelReasonCode, + Success, + PacketIdentifierNotFound +); + + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum PubRelReasonCode { +// Success, +// PacketIdentifierNotFound, +// } + +// impl MqttRead for PubRelReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(PubRelReasonCode::Success), +// 0x92 => Ok(PubRelReasonCode::PacketIdentifierNotFound), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttAsyncRead for PubRelReasonCode where S: tokio::io::AsyncReadExt + Unpin { +// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { +// let code = match stream.read_u8().await? { +// 0x00 => PubRelReasonCode::Success, +// 0x92 => PubRelReasonCode::PacketIdentifierNotFound, +// t => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnknownProperty(t))), +// }; +// Ok((code, 1)) +// } +// } + +// impl MqttWrite for PubRelReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// PubRelReasonCode::Success => 0x00, +// PubRelReasonCode::PacketIdentifierNotFound => 0x92, +// }; + +// buf.put_u8(val); +// Ok(()) +// } +// } diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs index 90bc37a..f86a8a3 100644 --- a/mqrstt/src/packets/reason_codes.rs +++ b/mqrstt/src/packets/reason_codes.rs @@ -3,34 +3,7 @@ use std::default; use bytes::{Buf, BufMut}; use super::error::DeserializeError; -use super::mqtt_traits::{MqttAsyncRead, MqttRead, MqttWrite}; - - - -super::macros::reason_code!(ConnAckReasonCode, - Success, - UnspecifiedError, - MalformedPacket, - ProtocolError, - ImplementationSpecificError, - UnsupportedProtocolVersion, - ClientIdentifierNotValid, - BadUsernameOrPassword, - NotAuthorized, - ServerUnavailable, - ServerBusy, - Banned, - BadAuthenticationMethod, - TopicNameInvalid, - PacketTooLarge, - QuotaExceeded, - PayloadFormatInvalid, - RetainNotSupported, - QosNotSupported, - UseAnotherServer, - ServerMoved, - ConnectionRateExceeded -); +use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}; // #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -153,12 +126,6 @@ super::macros::reason_code!(ConnAckReasonCode, // } // } -super::macros::reason_code!(AuthReasonCode, - Success, - ContinueAuthentication, - ReAuthenticate -); - // #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] // pub enum AuthReasonCode { // Success, @@ -195,442 +162,97 @@ super::macros::reason_code!(AuthReasonCode, // } // } -super::macros::reason_code!(DisconnectReasonCode, - NormalDisconnection, - DisconnectWithWillMessage, - UnspecifiedError, - MalformedPacket, - ProtocolError, - ImplementationSpecificError, - NotAuthorized, - ServerBusy, - ServerShuttingDown, - KeepAliveTimeout, - SessionTakenOver, - TopicFilterInvalid, - TopicNameInvalid, - ReceiveMaximumExceeded, - TopicAliasInvalid, - PacketTooLarge, - MessageRateTooHigh, - QuotaExceeded, - AdministrativeAction, - PayloadFormatInvalid, - RetainNotSupported, - QosNotSupported, - UseAnotherServer, - ServerMoved, - SharedSubscriptionsNotSupported, - ConnectionRateExceeded, - MaximumConnectTime, - SubscriptionIdentifiersNotSupported, - WildcardSubscriptionsNotSupported -); -// #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum DisconnectReasonCode { -// #[default] -// NormalDisconnection, -// DisconnectWithWillMessage, + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum PubCompReasonCode { +// Success, +// PacketIdentifierNotFound, +// } + +// impl MqttRead for PubCompReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(PubCompReasonCode::Success), +// 0x92 => Ok(PubCompReasonCode::PacketIdentifierNotFound), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } +// impl MqttWrite for PubCompReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { +// let val = match self { +// PubCompReasonCode::Success => 0x00, +// PubCompReasonCode::PacketIdentifierNotFound => 0x92, +// }; + +// buf.put_u8(val); +// Ok(()) +// } +// } + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum SubAckReasonCode { +// GrantedQoS0, +// GrantedQoS1, +// GrantedQoS2, // UnspecifiedError, -// MalformedPacket, -// ProtocolError, // ImplementationSpecificError, // NotAuthorized, -// ServerBusy, -// ServerShuttingDown, -// KeepAliveTimeout, -// SessionTakenOver, // TopicFilterInvalid, -// TopicNameInvalid, -// ReceiveMaximumExceeded, -// TopicAliasInvalid, -// PacketTooLarge, -// MessageRateTooHigh, +// PacketIdentifierInUse, // QuotaExceeded, -// AdministrativeAction, -// PayloadFormatInvalid, -// RetainNotSupported, -// QosNotSupported, -// UseAnotherServer, -// ServerMoved, // SharedSubscriptionsNotSupported, -// ConnectionRateExceeded, -// MaximumConnectTime, // SubscriptionIdentifiersNotSupported, // WildcardSubscriptionsNotSupported, // } -// impl MqttRead for DisconnectReasonCode { +// impl MqttRead for SubAckReasonCode { // fn read(buf: &mut bytes::Bytes) -> Result { // if buf.is_empty() { // return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); // } // match buf.get_u8() { -// 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), -// 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), -// 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), -// 0x81 => Ok(DisconnectReasonCode::MalformedPacket), -// 0x82 => Ok(DisconnectReasonCode::ProtocolError), -// 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), -// 0x87 => Ok(DisconnectReasonCode::NotAuthorized), -// 0x89 => Ok(DisconnectReasonCode::ServerBusy), -// 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), -// 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), -// 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), -// 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), -// 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), -// 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), -// 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), -// 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), -// 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), -// 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), -// 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), -// 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), -// 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), -// 0x9B => Ok(DisconnectReasonCode::QosNotSupported), -// 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), -// 0x9D => Ok(DisconnectReasonCode::ServerMoved), -// 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), -// 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), -// 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), -// 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), -// 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), +// 0x00 => Ok(SubAckReasonCode::GrantedQoS0), +// 0x01 => Ok(SubAckReasonCode::GrantedQoS1), +// 0x02 => Ok(SubAckReasonCode::GrantedQoS2), +// 0x80 => Ok(SubAckReasonCode::UnspecifiedError), +// 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), +// 0x87 => Ok(SubAckReasonCode::NotAuthorized), +// 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), +// 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), +// 0x97 => Ok(SubAckReasonCode::QuotaExceeded), +// 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), +// 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), +// 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), // t => Err(DeserializeError::UnknownProperty(t)), // } // } // } -// impl MqttWrite for DisconnectReasonCode { +// impl MqttWrite for SubAckReasonCode { // fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { // let val = match self { -// DisconnectReasonCode::NormalDisconnection => 0x00, -// DisconnectReasonCode::DisconnectWithWillMessage => 0x04, -// DisconnectReasonCode::UnspecifiedError => 0x80, -// DisconnectReasonCode::MalformedPacket => 0x81, -// DisconnectReasonCode::ProtocolError => 0x82, -// DisconnectReasonCode::ImplementationSpecificError => 0x83, -// DisconnectReasonCode::NotAuthorized => 0x87, -// DisconnectReasonCode::ServerBusy => 0x89, -// DisconnectReasonCode::ServerShuttingDown => 0x8B, - -// DisconnectReasonCode::KeepAliveTimeout => 0x8D, -// DisconnectReasonCode::SessionTakenOver => 0x8E, -// DisconnectReasonCode::TopicFilterInvalid => 0x8F, - -// DisconnectReasonCode::TopicNameInvalid => 0x90, -// DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, -// DisconnectReasonCode::TopicAliasInvalid => 0x94, -// DisconnectReasonCode::PacketTooLarge => 0x95, -// DisconnectReasonCode::MessageRateTooHigh => 0x96, -// DisconnectReasonCode::QuotaExceeded => 0x97, -// DisconnectReasonCode::AdministrativeAction => 0x98, -// DisconnectReasonCode::PayloadFormatInvalid => 0x99, -// DisconnectReasonCode::RetainNotSupported => 0x9A, -// DisconnectReasonCode::QosNotSupported => 0x9B, -// DisconnectReasonCode::UseAnotherServer => 0x9C, -// DisconnectReasonCode::ServerMoved => 0x9D, -// DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, -// DisconnectReasonCode::ConnectionRateExceeded => 0x9F, -// DisconnectReasonCode::MaximumConnectTime => 0xA0, -// DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, -// DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, +// SubAckReasonCode::GrantedQoS0 => 0x00, +// SubAckReasonCode::GrantedQoS1 => 0x01, +// SubAckReasonCode::GrantedQoS2 => 0x02, +// SubAckReasonCode::UnspecifiedError => 0x80, +// SubAckReasonCode::ImplementationSpecificError => 0x83, +// SubAckReasonCode::NotAuthorized => 0x87, +// SubAckReasonCode::TopicFilterInvalid => 0x8F, +// SubAckReasonCode::PacketIdentifierInUse => 0x91, +// SubAckReasonCode::QuotaExceeded => 0x97, +// SubAckReasonCode::SharedSubscriptionsNotSupported => 0x9E, +// SubAckReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, +// SubAckReasonCode::WildcardSubscriptionsNotSupported => 0xA2, // }; // buf.put_u8(val); - // Ok(()) // } -// } - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubAckReasonCode { - Success, - NoMatchingSubscribers, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicNameInvalid, - PacketIdentifierInUse, - QuotaExceeded, - PayloadFormatInvalid, -} - -impl MqttRead for PubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubAckReasonCode::Success), - 0x10 => Ok(PubAckReasonCode::NoMatchingSubscribers), - 0x80 => Ok(PubAckReasonCode::UnspecifiedError), - 0x83 => Ok(PubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(PubAckReasonCode::NotAuthorized), - 0x90 => Ok(PubAckReasonCode::TopicNameInvalid), - 0x91 => Ok(PubAckReasonCode::PacketIdentifierInUse), - 0x97 => Ok(PubAckReasonCode::QuotaExceeded), - 0x99 => Ok(PubAckReasonCode::PayloadFormatInvalid), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubAckReasonCode::Success => 0x00, - PubAckReasonCode::NoMatchingSubscribers => 0x10, - PubAckReasonCode::UnspecifiedError => 0x80, - PubAckReasonCode::ImplementationSpecificError => 0x83, - PubAckReasonCode::NotAuthorized => 0x87, - PubAckReasonCode::TopicNameInvalid => 0x90, - PubAckReasonCode::PacketIdentifierInUse => 0x91, - PubAckReasonCode::QuotaExceeded => 0x97, - PubAckReasonCode::PayloadFormatInvalid => 0x99, - }; - - buf.put_u8(val); - - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubCompReasonCode { - Success, - PacketIdentifierNotFound, -} - -impl MqttRead for PubCompReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubCompReasonCode::Success), - 0x92 => Ok(PubCompReasonCode::PacketIdentifierNotFound), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubCompReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubCompReasonCode::Success => 0x00, - PubCompReasonCode::PacketIdentifierNotFound => 0x92, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubRecReasonCode { - Success, - NoMatchingSubscribers, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicNameInvalid, - PacketIdentifierInUse, - QuotaExceeded, - PayloadFormatInvalid, -} - -impl MqttRead for PubRecReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubRecReasonCode::Success), - 0x10 => Ok(PubRecReasonCode::NoMatchingSubscribers), - 0x80 => Ok(PubRecReasonCode::UnspecifiedError), - 0x83 => Ok(PubRecReasonCode::ImplementationSpecificError), - 0x87 => Ok(PubRecReasonCode::NotAuthorized), - 0x90 => Ok(PubRecReasonCode::TopicNameInvalid), - 0x91 => Ok(PubRecReasonCode::PacketIdentifierInUse), - 0x97 => Ok(PubRecReasonCode::QuotaExceeded), - 0x99 => Ok(PubRecReasonCode::PayloadFormatInvalid), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for PubRecReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubRecReasonCode::Success => 0x00, - PubRecReasonCode::NoMatchingSubscribers => 0x10, - PubRecReasonCode::UnspecifiedError => 0x80, - PubRecReasonCode::ImplementationSpecificError => 0x83, - PubRecReasonCode::NotAuthorized => 0x87, - PubRecReasonCode::TopicNameInvalid => 0x90, - PubRecReasonCode::PacketIdentifierInUse => 0x91, - PubRecReasonCode::QuotaExceeded => 0x97, - PubRecReasonCode::PayloadFormatInvalid => 0x99, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum PubRelReasonCode { - Success, - PacketIdentifierNotFound, -} - -impl MqttRead for PubRelReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(PubRelReasonCode::Success), - 0x92 => Ok(PubRelReasonCode::PacketIdentifierNotFound), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttAsyncRead for PubRelReasonCode where S: tokio::io::AsyncReadExt + Unpin { - async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { - let code = match stream.read_u8().await? { - 0x00 => PubRelReasonCode::Success, - 0x92 => PubRelReasonCode::PacketIdentifierNotFound, - t => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnknownProperty(t))), - }; - Ok((code, 1)) - } -} - -impl MqttWrite for PubRelReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - PubRelReasonCode::Success => 0x00, - PubRelReasonCode::PacketIdentifierNotFound => 0x92, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum SubAckReasonCode { - GrantedQoS0, - GrantedQoS1, - GrantedQoS2, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicFilterInvalid, - PacketIdentifierInUse, - QuotaExceeded, - SharedSubscriptionsNotSupported, - SubscriptionIdentifiersNotSupported, - WildcardSubscriptionsNotSupported, -} - -impl MqttRead for SubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(SubAckReasonCode::GrantedQoS0), - 0x01 => Ok(SubAckReasonCode::GrantedQoS1), - 0x02 => Ok(SubAckReasonCode::GrantedQoS2), - 0x80 => Ok(SubAckReasonCode::UnspecifiedError), - 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(SubAckReasonCode::NotAuthorized), - 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), - 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), - 0x97 => Ok(SubAckReasonCode::QuotaExceeded), - 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), - 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), - 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for SubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - SubAckReasonCode::GrantedQoS0 => 0x00, - SubAckReasonCode::GrantedQoS1 => 0x01, - SubAckReasonCode::GrantedQoS2 => 0x02, - SubAckReasonCode::UnspecifiedError => 0x80, - SubAckReasonCode::ImplementationSpecificError => 0x83, - SubAckReasonCode::NotAuthorized => 0x87, - SubAckReasonCode::TopicFilterInvalid => 0x8F, - SubAckReasonCode::PacketIdentifierInUse => 0x91, - SubAckReasonCode::QuotaExceeded => 0x97, - SubAckReasonCode::SharedSubscriptionsNotSupported => 0x9E, - SubAckReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, - SubAckReasonCode::WildcardSubscriptionsNotSupported => 0xA2, - }; - - buf.put_u8(val); - Ok(()) - } -} - -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -pub enum UnsubAckReasonCode { - Success, - NoSubscriptionExisted, - UnspecifiedError, - ImplementationSpecificError, - NotAuthorized, - TopicFilterInvalid, - PacketIdentifierInUse, -} - -impl MqttRead for UnsubAckReasonCode { - fn read(buf: &mut bytes::Bytes) -> Result { - if buf.is_empty() { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); - } - - match buf.get_u8() { - 0x00 => Ok(UnsubAckReasonCode::Success), - 0x11 => Ok(UnsubAckReasonCode::NoSubscriptionExisted), - 0x80 => Ok(UnsubAckReasonCode::UnspecifiedError), - 0x83 => Ok(UnsubAckReasonCode::ImplementationSpecificError), - 0x87 => Ok(UnsubAckReasonCode::NotAuthorized), - 0x8F => Ok(UnsubAckReasonCode::TopicFilterInvalid), - 0x91 => Ok(UnsubAckReasonCode::PacketIdentifierInUse), - t => Err(DeserializeError::UnknownProperty(t)), - } - } -} - -impl MqttWrite for UnsubAckReasonCode { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - let val = match self { - UnsubAckReasonCode::Success => 0x00, - UnsubAckReasonCode::NoSubscriptionExisted => 0x11, - UnsubAckReasonCode::UnspecifiedError => 0x80, - UnsubAckReasonCode::ImplementationSpecificError => 0x83, - UnsubAckReasonCode::NotAuthorized => 0x87, - UnsubAckReasonCode::TopicFilterInvalid => 0x8F, - UnsubAckReasonCode::PacketIdentifierInUse => 0x91, - }; - - buf.put_u8(val); - Ok(()) - } -} +// } \ No newline at end of file diff --git a/mqrstt/src/packets/suback.rs b/mqrstt/src/packets/suback.rs deleted file mode 100644 index 95718f2..0000000 --- a/mqrstt/src/packets/suback.rs +++ /dev/null @@ -1,160 +0,0 @@ -use bytes::BufMut; - -use super::{ - error::{DeserializeError, SerializeError}, - mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - read_variable_integer, - reason_codes::SubAckReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -///3.9 SUBACK – Subscribe acknowledgement -/// A SUBACK packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE packet. -/// A SUBACK packet contains a list of Reason Codes, that specify the maximum QoS level that was granted or the error which was found for each Subscription that was requested by the SUBSCRIBE. -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubAck { - pub packet_identifier: u16, - pub properties: SubAckProperties, - pub reason_codes: Vec, -} - -impl PacketRead for SubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let packet_identifier = u16::read(&mut buf)?; - let properties = SubAckProperties::read(&mut buf)?; - let mut reason_codes = vec![]; - loop { - let reason_code = SubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - - if buf.is_empty() { - break; - } - } - - Ok(Self { - packet_identifier, - properties, - reason_codes, - }) - } -} - -impl PacketWrite for SubAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.packet_identifier); - - self.properties.write(buf)?; - for reason_code in &self.reason_codes { - reason_code.write(buf)?; - } - - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubAckProperties { - /// 3.8.2.1.2 Subscription Identifier - /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. - pub subscription_id: Option, - - /// 3.8.2.1.3 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for SubAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = SubAckProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::SubscriptionIdentifier => { - if properties.subscription_id.is_none() { - let (subscription_id, _) = read_variable_integer(&mut properties_data)?; - - properties.subscription_id = Some(subscription_id); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), - } - - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for SubAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - if let Some(sub_id) = self.subscription_id { - PropertyType::SubscriptionIdentifier.write(buf)?; - write_variable_integer(buf, sub_id)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for SubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(sub_id) = self.subscription_id { - len += 1 + variable_integer_len(sub_id); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - -#[cfg(test)] -mod test { - use bytes::BytesMut; - - use super::SubAck; - use crate::packets::mqtt_traits::{PacketRead, PacketWrite}; - - #[test] - fn read_write_suback() { - let buf = vec![ - 0x00, 0x0F, // variable header. pkid = 15 - 0x00, // Property length 0 - 0x01, // Payload reason code codes Granted QoS 1, - 0x80, // Payload Unspecified error - ]; - - let data = BytesMut::from(&buf[..]); - let sub_ack = SubAck::read(0, 0, data.clone().into()).unwrap(); - - let mut result = BytesMut::new(); - sub_ack.write(&mut result).unwrap(); - - assert_eq!(data.to_vec(), result.to_vec()); - } -} diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs new file mode 100644 index 0000000..2702066 --- /dev/null +++ b/mqrstt/src/packets/suback/mod.rs @@ -0,0 +1,116 @@ +mod properties; +pub use properties::SubAckProperties; + +mod reason_code; +pub use reason_code::SubAckReasonCode; + +use bytes::BufMut; + +use super::{ + error::{SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + +}; + +/// 3.9 SUBACK – Subscribe acknowledgement +/// A SUBACK packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE packet. +/// A SUBACK packet contains a list of Reason Codes, that specify the maximum QoS level that was granted or the error which was found for each Subscription that was requested by the SUBSCRIBE. +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct SubAck { + pub packet_identifier: u16, + pub properties: SubAckProperties, + pub reason_codes: Vec, +} + +impl PacketRead for SubAck { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + let packet_identifier = u16::read(&mut buf)?; + let properties = SubAckProperties::read(&mut buf)?; + + dbg!("aa"); + + let mut reason_codes = vec![]; + loop { + let reason_code = SubAckReasonCode::read(&mut buf)?; + + dbg!(reason_code); + reason_codes.push(reason_code); + + if buf.is_empty() { + break; + } + } + + Ok(Self { + packet_identifier, + properties, + reason_codes, + }) + } +} + +impl PacketAsyncRead for SubAck where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; + total_read_bytes += id_read_bytes + proproperties_read_bytes; + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); + + if remaining_length == total_read_bytes { + break; + } + } + + Ok((Self { + packet_identifier, + properties, + reason_codes, + }, total_read_bytes)) + } + } +} + +impl PacketWrite for SubAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.packet_identifier); + + self.properties.write(buf)?; + for reason_code in &self.reason_codes { + reason_code.write(buf)?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use bytes::BytesMut; + + use super::SubAck; + use crate::packets::mqtt_trait::{PacketRead, PacketWrite}; + + #[test] + fn read_write_suback() { + let buf = vec![ + 0x00, 0x0F, // variable header. pkid = 15 + 0x00, // Property length 0 + 0x01, // Payload reason code codes Granted QoS 1, + 0x80, // Payload Unspecified error + ]; + + let data = BytesMut::from(&buf[..]); + let sub_ack = SubAck::read(0, 0, data.clone().into()).unwrap(); + + let mut result = BytesMut::new(); + sub_ack.write(&mut result).unwrap(); + + assert_eq!(data.to_vec(), result.to_vec()); + } +} diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs new file mode 100644 index 0000000..8a06e7e --- /dev/null +++ b/mqrstt/src/packets/suback/properties.rs @@ -0,0 +1,92 @@ +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, +}; + +use crate::packets::primitive::VariableInteger; + +crate::packets::macros::define_properties!(SubAckProperties, + SubscriptionIdentifier, + UserProperty +); + +// #[derive(Debug, Default, PartialEq, Eq, Clone)] +// pub struct SubAckProperties { +// /// 3.8.2.1.2 Subscription Identifier +// /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. +// pub subscription_id: Option, + +// /// 3.8.2.1.3 User Property +// /// 38 (0x26) Byte, Identifier of the User Property. +// pub user_properties: Vec<(Box, Box)>, +// } + +impl MqttRead for SubAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = SubAckProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::SubscriptionIdentifier => { + if properties.subscription_identifier.is_none() { + let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; + + properties.subscription_identifier = Some(subscription_id); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), + } + + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for SubAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(sub_id) = self.subscription_identifier { + PropertyType::SubscriptionIdentifier.write(buf)?; + sub_id.write_variable_integer(buf)?; + // sub_id.write_variable_integer(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} + +// impl WireLength for SubAckProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(sub_id) = self.subscription_id { +// len += 1 + sub_id.variable_integer_len(); +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } +// len +// } +// } diff --git a/mqrstt/src/packets/suback/reason_code.rs b/mqrstt/src/packets/suback/reason_code.rs new file mode 100644 index 0000000..3185a87 --- /dev/null +++ b/mqrstt/src/packets/suback/reason_code.rs @@ -0,0 +1,61 @@ + + +crate::packets::macros::reason_code!( + SubAckReasonCode, + GrantedQoS0, + GrantedQoS1, + GrantedQoS2, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicFilterInvalid, + PacketIdentifierInUse, + QuotaExceeded, + SharedSubscriptionsNotSupported, + SubscriptionIdentifiersNotSupported, + WildcardSubscriptionsNotSupported +); + + + + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum SubAckReasonCode { +// GrantedQoS0, +// GrantedQoS1, +// GrantedQoS2, +// ImplementationSpecificError, +// NotAuthorized, +// TopicFilterInvalid, +// PacketIdentifierInUse, +// QuotaExceeded, +// SharedSubscriptionsNotSupported, +// SubscriptionIdentifiersNotSupported, +// WildcardSubscriptionsNotSupported, +// } + +// impl MqttRead for SubAckReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(SubAckReasonCode::GrantedQoS0), +// 0x01 => Ok(SubAckReasonCode::GrantedQoS1), +// 0x02 => Ok(SubAckReasonCode::GrantedQoS2), +// 0x80 => Ok(SubAckReasonCode::UnspecifiedError), +// 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), +// 0x87 => Ok(SubAckReasonCode::NotAuthorized), +// 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), +// 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), +// 0x97 => Ok(SubAckReasonCode::QuotaExceeded), +// 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), +// 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), +// 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for SubAckReasonCode { \ No newline at end of file diff --git a/mqrstt/src/packets/subscribe.rs b/mqrstt/src/packets/subscribe.rs index 0142ea5..3263205 100644 --- a/mqrstt/src/packets/subscribe.rs +++ b/mqrstt/src/packets/subscribe.rs @@ -2,8 +2,8 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, QoS, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketValidation, PacketWrite, WireLength}, + PacketType, PropertyType, QoS, VariableInteger, }; use bytes::{Buf, BufMut}; @@ -66,7 +66,7 @@ impl WireLength for Subscribe { fn wire_len(&self) -> usize { let mut len = 2; let properties_len = self.properties.wire_len(); - len += properties_len + variable_integer_len(properties_len); + len += properties_len + properties_len.variable_integer_len(); for topic in &self.topics { len += topic.0.wire_len() + 1; } @@ -101,7 +101,7 @@ pub struct SubscribeProperties { impl MqttRead for SubscribeProperties { fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; + let (len, _) = VariableInteger::read_variable_integer(buf)?; let mut properties = SubscribeProperties::default(); @@ -117,7 +117,7 @@ impl MqttRead for SubscribeProperties { match PropertyType::read(&mut properties_data)? { PropertyType::SubscriptionIdentifier => { if properties.subscription_id.is_none() { - let (subscription_id, _) = read_variable_integer(&mut properties_data)?; + let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; properties.subscription_id = Some(subscription_id); } else { @@ -140,10 +140,10 @@ impl MqttRead for SubscribeProperties { impl MqttWrite for SubscribeProperties { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; + self.wire_len().write_variable_integer(buf)?; if let Some(sub_id) = self.subscription_id { PropertyType::SubscriptionIdentifier.write(buf)?; - write_variable_integer(buf, sub_id)?; + sub_id.write_variable_integer(buf)?; } for (key, value) in &self.user_properties { PropertyType::UserProperty.write(buf)?; @@ -158,7 +158,7 @@ impl WireLength for SubscribeProperties { fn wire_len(&self) -> usize { let mut len = 0; if let Some(sub_id) = self.subscription_id { - len += 1 + variable_integer_len(sub_id); + len += 1 + sub_id.variable_integer_len(); } for (key, value) in &self.user_properties { len += 1 + key.wire_len() + value.wire_len(); @@ -390,7 +390,7 @@ mod tests { use bytes::{Bytes, BytesMut}; use crate::packets::{ - mqtt_traits::{MqttRead, PacketRead, PacketWrite}, + mqtt_trait::{MqttRead, PacketRead, PacketWrite}, Packet, }; diff --git a/mqrstt/src/packets/subscribe/properties.rs b/mqrstt/src/packets/subscribe/properties.rs new file mode 100644 index 0000000..e69de29 diff --git a/mqrstt/src/packets/subscribe/reason_code.rs b/mqrstt/src/packets/subscribe/reason_code.rs new file mode 100644 index 0000000..e69de29 diff --git a/mqrstt/src/packets/unsuback.rs b/mqrstt/src/packets/unsuback.rs deleted file mode 100644 index f7c556b..0000000 --- a/mqrstt/src/packets/unsuback.rs +++ /dev/null @@ -1,149 +0,0 @@ -use bytes::BufMut; - -use super::error::{DeserializeError, SerializeError}; -use super::mqtt_traits::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; -use super::{read_variable_integer, reason_codes::UnsubAckReasonCode, write_variable_integer, PacketType, PropertyType}; - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubAck { - pub packet_identifier: u16, - pub properties: UnsubAckProperties, - pub reason_codes: Vec, -} - -impl PacketRead for UnsubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { - let packet_identifier = u16::read(&mut buf)?; - let properties = UnsubAckProperties::read(&mut buf)?; - let mut reason_codes = vec![]; - loop { - let reason_code = UnsubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - - if buf.is_empty() { - break; - } - } - - Ok(Self { - packet_identifier, - properties, - reason_codes, - }) - } -} - -impl PacketWrite for UnsubAck { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { - buf.put_u16(self.packet_identifier); - self.properties.write(buf)?; - for reason_code in &self.reason_codes { - reason_code.write(buf)?; - } - Ok(()) - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubAckProperties { - /// 3.11.2.1.2 Reason String - /// 31 (0x1F) Byte, Identifier of the Reason String. - pub reason_string: Option>, - - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for UnsubAckProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = UnsubAckProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::ReasonString => { - if properties.reason_string.is_none() { - properties.reason_string = Some(Box::::read(&mut properties_data)?); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), - } - - if buf.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for UnsubAckProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for UnsubAckProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(reason_string) = &self.reason_string { - len += 1 + reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - -#[cfg(test)] -mod tests { - use bytes::{Bytes, BytesMut}; - - use crate::packets::{ - mqtt_traits::{PacketRead, PacketWrite}, - unsuback::UnsubAck, - }; - - #[test] - fn read_write_unsub_ack() { - // let entire_mqtt_packet = [0xb0, 0x04, 0x35, 0xd7, 0x00, 0x00]; - let unsub_ack = [0x35, 0xd7, 0x00, 0x00]; - - let mut bufmut = BytesMut::new(); - bufmut.extend(&unsub_ack[..]); - - let buf: Bytes = bufmut.into(); - - let s = UnsubAck::read(0xb0, 4, buf.clone()).unwrap(); - - let mut result = BytesMut::new(); - s.write(&mut result).unwrap(); - - assert_eq!(buf.to_vec(), result.to_vec()); - } -} diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs new file mode 100644 index 0000000..6e3bb25 --- /dev/null +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -0,0 +1,80 @@ +mod properties; +pub use properties::UnsubAckProperties; + +mod reason_code; +pub use reason_code::UnsubAckReasonCode; + + +use bytes::BufMut; + +use super::error::{SerializeError}; +use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; + +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct UnsubAck { + pub packet_identifier: u16, + pub properties: UnsubAckProperties, + pub reason_codes: Vec, +} + +impl PacketRead for UnsubAck { + fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + let packet_identifier = u16::read(&mut buf)?; + let properties = UnsubAckProperties::read(&mut buf)?; + let mut reason_codes = vec![]; + loop { + let reason_code = UnsubAckReasonCode::read(&mut buf)?; + + reason_codes.push(reason_code); + + if buf.is_empty() { + break; + } + } + + Ok(Self { + packet_identifier, + properties, + reason_codes, + }) + } +} + +impl PacketWrite for UnsubAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + buf.put_u16(self.packet_identifier); + self.properties.write(buf)?; + for reason_code in &self.reason_codes { + reason_code.write(buf)?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use bytes::{Bytes, BytesMut}; + + use crate::packets::{ + mqtt_trait::{PacketRead, PacketWrite}, + unsuback::UnsubAck, + }; + + #[test] + fn read_write_unsub_ack() { + // let entire_mqtt_packet = [0xb0, 0x04, 0x35, 0xd7, 0x00, 0x00]; + let unsub_ack = [0x35, 0xd7, 0x00, 0x00]; + + let mut bufmut = BytesMut::new(); + bufmut.extend(&unsub_ack[..]); + + let buf: Bytes = bufmut.into(); + + let s = UnsubAck::read(0xb0, 4, buf.clone()).unwrap(); + + let mut result = BytesMut::new(); + s.write(&mut result).unwrap(); + + assert_eq!(buf.to_vec(), result.to_vec()); + } +} diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs new file mode 100644 index 0000000..bd90e8e --- /dev/null +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -0,0 +1,82 @@ +use bytes::BufMut; + +use crate::packets::error::{DeserializeError}; +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; +use crate::packets::{PacketType, PropertyType, VariableInteger}; + +crate::packets::macros::define_properties!(UnsubAckProperties, ReasonString, UserProperty); + +// #[derive(Debug, Default, PartialEq, Eq, Clone)] +// pub struct UnsubAckProperties { +// /// 3.11.2.1.2 Reason String +// /// 31 (0x1F) Byte, Identifier of the Reason String. +// pub reason_string: Option>, + +// pub user_properties: Vec<(Box, Box)>, +// } + +impl MqttRead for UnsubAckProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = UnsubAckProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::ReasonString => { + if properties.reason_string.is_none() { + properties.reason_string = Some(Box::::read(&mut properties_data)?); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), + } + + if buf.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for UnsubAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} + +// impl WireLength for UnsubAckProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(reason_string) = &self.reason_string { +// len += 1 + reason_string.wire_len(); +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } +// len +// } +// } diff --git a/mqrstt/src/packets/unsuback/reason_code.rs b/mqrstt/src/packets/unsuback/reason_code.rs new file mode 100644 index 0000000..292655f --- /dev/null +++ b/mqrstt/src/packets/unsuback/reason_code.rs @@ -0,0 +1,58 @@ +crate::packets::macros::reason_code!( + UnsubAckReasonCode, + Success, + NoSubscriptionExisted, + UnspecifiedError, + ImplementationSpecificError, + NotAuthorized, + TopicFilterInvalid, + PacketIdentifierInUse +); + + +// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] +// pub enum UnsubAckReasonCode { +// Success, +// NoSubscriptionExisted, +// UnspecifiedError, +// ImplementationSpecificError, +// NotAuthorized, +// TopicFilterInvalid, +// PacketIdentifierInUse, +// } + +// impl MqttRead for UnsubAckReasonCode { +// fn read(buf: &mut bytes::Bytes) -> Result { +// if buf.is_empty() { +// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); +// } + +// match buf.get_u8() { +// 0x00 => Ok(UnsubAckReasonCode::Success), +// 0x11 => Ok(UnsubAckReasonCode::NoSubscriptionExisted), +// 0x80 => Ok(UnsubAckReasonCode::UnspecifiedError), +// 0x83 => Ok(UnsubAckReasonCode::ImplementationSpecificError), +// 0x87 => Ok(UnsubAckReasonCode::NotAuthorized), +// 0x8F => Ok(UnsubAckReasonCode::TopicFilterInvalid), +// 0x91 => Ok(UnsubAckReasonCode::PacketIdentifierInUse), +// t => Err(DeserializeError::UnknownProperty(t)), +// } +// } +// } + +// impl MqttWrite for UnsubAckReasonCode { +// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { +// let val = match self { +// UnsubAckReasonCode::Success => 0x00, +// UnsubAckReasonCode::NoSubscriptionExisted => 0x11, +// UnsubAckReasonCode::UnspecifiedError => 0x80, +// UnsubAckReasonCode::ImplementationSpecificError => 0x83, +// UnsubAckReasonCode::NotAuthorized => 0x87, +// UnsubAckReasonCode::TopicFilterInvalid => 0x8F, +// UnsubAckReasonCode::PacketIdentifierInUse => 0x91, +// }; + +// buf.put_u8(val); +// Ok(()) +// } +// } diff --git a/mqrstt/src/packets/unsubscribe.rs b/mqrstt/src/packets/unsubscribe/mod.rs similarity index 76% rename from mqrstt/src/packets/unsubscribe.rs rename to mqrstt/src/packets/unsubscribe/mod.rs index af746ee..f6e3bfc 100644 --- a/mqrstt/src/packets/unsubscribe.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -1,13 +1,23 @@ +mod properties; +pub use properties::UnsubscribeProperties; + use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; +use crate::packets::mqtt_trait::MqttAsyncRead; + +use super::VariableInteger; use super::{ error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, - read_variable_integer, variable_integer_len, write_variable_integer, PacketType, PropertyType, + mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, + PacketType, PropertyType, }; use bytes::BufMut; #[derive(Debug, Clone, PartialEq, Eq)] +/// Used to unsubscribe from topics. +/// +/// Multiple topics can be unsubscribed from at once. +/// For convenience [`UnsubscribeTopics`] is provided. pub struct Unsubscribe { pub packet_identifier: u16, pub properties: UnsubscribeProperties, @@ -47,6 +57,35 @@ impl PacketRead for Unsubscribe { } } +impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncReadExt + Unpin{ + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; + total_read_bytes += id_read_bytes + properties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_size) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_size; + + topics.push(topic); + + if total_read_bytes >= remaining_length { + break; + } + } + + Ok((Self { + packet_identifier, + properties, + topics, + }, total_read_bytes)) + } + } +} + impl PacketWrite for Unsubscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -62,7 +101,7 @@ impl PacketWrite for Unsubscribe { impl WireLength for Unsubscribe { fn wire_len(&self) -> usize { - let mut len = 2 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len(); + let mut len = 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len(); for topic in &self.topics { len += topic.wire_len(); } @@ -84,63 +123,6 @@ impl PacketValidation for Unsubscribe { } } -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct UnsubscribeProperties { - pub user_properties: Vec<(String, String)>, -} - -impl MqttRead for UnsubscribeProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = UnsubscribeProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::UserProperty => { - properties.user_properties.push((String::read(&mut properties_data)?, String::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Unsubscribe)), - } - - if properties_data.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for UnsubscribeProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for UnsubscribeProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - trait IntoUnsubscribeTopic { fn into(value: Self) -> Box; } @@ -280,7 +262,7 @@ mod tests { use bytes::{Bytes, BytesMut}; - use crate::packets::mqtt_traits::{PacketRead, PacketWrite}; + use crate::packets::mqtt_trait::{PacketRead, PacketWrite}; use super::Unsubscribe; diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs new file mode 100644 index 0000000..beab1d5 --- /dev/null +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -0,0 +1,69 @@ +use crate::packets::VariableInteger; +use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, + PacketType, PropertyType, +}; +use bytes::BufMut; + + +crate::packets::macros::define_properties!(UnsubscribeProperties, UserProperty); + +// #[derive(Debug, Default, PartialEq, Eq, Clone)] +// pub struct UnsubscribeProperties { +// pub user_properties: Vec<(String, String)>, +// } + +impl MqttRead for UnsubscribeProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = UnsubscribeProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Unsubscribe)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for UnsubscribeProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} + +// impl WireLength for UnsubscribeProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } +// len +// } +// } diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 20ef952..29ab074 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -9,8 +9,7 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::ConnectionError; use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; -use crate::packets::{Disconnect, Packet, PacketType}; +use crate::packets::{Disconnect, DisconnectReasonCode, Packet, PacketType}; use crate::NetworkStatus; use crate::{AsyncEventHandlerMut, StateHandler}; diff --git a/mqrstt/src/smol/stream.rs b/mqrstt/src/smol/stream.rs index a24c1a6..efa3a8f 100644 --- a/mqrstt/src/smol/stream.rs +++ b/mqrstt/src/smol/stream.rs @@ -10,7 +10,7 @@ use tracing::trace; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, - reason_codes::ConnAckReasonCode, + ConnAckReasonCode, {FixedHeader, Packet}, }; use crate::{connect_options::ConnectOptions, error::ConnectionError}; diff --git a/mqrstt/src/state_handler.rs b/mqrstt/src/state_handler.rs index 03ea6fa..d0eeeb9 100644 --- a/mqrstt/src/state_handler.rs +++ b/mqrstt/src/state_handler.rs @@ -1,7 +1,7 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::HandlerError; -use crate::packets::reason_codes::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; +use crate::packets::{PubRecReasonCode, PubAckReasonCode, ConnAckReasonCode}; use crate::packets::PubComp; use crate::packets::PubRec; use crate::packets::PubRel; @@ -20,7 +20,7 @@ use crate::state::State; use tracing::{debug, error, info, warn}; /// Eventloop with all the state of a connection -pub struct StateHandler { +pub(crate) struct StateHandler { state: State, clean_start: bool, } @@ -247,8 +247,13 @@ mod handler_tests { use crate::{ available_packet_ids::AvailablePacketIds, packets::{ - reason_codes::{PubCompReasonCode, PubRecReasonCode, PubRelReasonCode, SubAckReasonCode, UnsubAckReasonCode}, - Packet, QoS, UnsubAck, UnsubAckProperties, {PubComp, PubCompProperties}, {PubRec, PubRecProperties}, {PubRel, PubRelProperties}, {SubAck, SubAckProperties}, + Packet, + PubComp, PubCompProperties, PubCompReasonCode, + PubRec, PubRecProperties, PubRecReasonCode, + PubRel, PubRelProperties, PubRelReasonCode, + QoS, + SubAck, SubAckProperties, SubAckReasonCode, + UnsubAck, UnsubAckProperties, UnsubAckReasonCode }, tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, ConnectOptions, StateHandler, diff --git a/mqrstt/src/tests/test_bytes.rs b/mqrstt/src/tests/test_bytes.rs index 93a35ef..04b21af 100644 --- a/mqrstt/src/tests/test_bytes.rs +++ b/mqrstt/src/tests/test_bytes.rs @@ -2,7 +2,7 @@ use rstest::*; use bytes::BytesMut; -use crate::packets::{mqtt_traits::WireLength, Packet}; +use crate::packets::{mqtt_trait::WireLength, Packet}; fn publish_packet() -> Vec { const PUBLISH_BYTES: [u8; 79] = [ diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 5dae121..08c39b9 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -3,8 +3,7 @@ use bytes::Bytes; use rstest::*; use crate::packets::{ - reason_codes::{DisconnectReasonCode, PubAckReasonCode}, - ConnAck, Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, Publish, PublishProperties, QoS, Subscribe, Subscription, Unsubscribe, + DisconnectReasonCode, ConnAck, Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, PubAckReasonCode, Publish, PublishProperties, QoS, Subscribe, Subscription, Unsubscribe }; fn publish_packet_1() -> Packet { @@ -20,7 +19,7 @@ fn publish_packet_1() -> Packet { topic_alias: None, response_topic: None, correlation_data: Some(b"1212".to_vec()), - subscription_identifier: vec![1], + subscription_identifiers: vec![1], user_properties: vec![], content_type: None, }, @@ -40,7 +39,7 @@ fn publish_packet_2() -> Packet { topic_alias: Some(1), response_topic: None, correlation_data: Some(b"1212".to_vec()), - subscription_identifier: vec![1], + subscription_identifiers: vec![1], user_properties: vec![], content_type: None, }, @@ -60,7 +59,7 @@ fn publish_packet_3() -> Packet { topic_alias: None, response_topic: Some("Please respond here thank you".into()), correlation_data: Some(b"5420874".to_vec()), - subscription_identifier: vec![], + subscription_identifiers: vec![], user_properties: vec![("blabla".into(), "another blabla".into())], content_type: None, }, @@ -80,7 +79,7 @@ fn publish_packet_4() -> Packet { topic_alias: Some(1), response_topic: None, correlation_data: Some(b"1212".to_vec()), - subscription_identifier: vec![1], + subscription_identifiers: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), }, @@ -113,7 +112,7 @@ pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifie topic_alias: Some(1), response_topic: None, correlation_data: Some(b"1212".to_vec()), - subscription_identifier: vec![1], + subscription_identifiers: vec![1], user_properties: vec![], content_type: Some("Garbage".into()), }, diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 8a2be8d..f55395d 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -10,7 +10,7 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::ConnectionError; use crate::packets::error::ReadBytes; -use crate::packets::reason_codes::DisconnectReasonCode; +use crate::packets::DisconnectReasonCode; use crate::packets::{Disconnect, Packet, PacketType}; use crate::{AsyncEventHandlerMut, NetworkStatus, StateHandler}; diff --git a/mqrstt/src/tokio/stream/mod.rs b/mqrstt/src/tokio/stream/mod.rs index ad928d5..463d109 100644 --- a/mqrstt/src/tokio/stream/mod.rs +++ b/mqrstt/src/tokio/stream/mod.rs @@ -13,7 +13,7 @@ use tracing::trace; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, - reason_codes::ConnAckReasonCode, + ConnAckReasonCode, {FixedHeader, Packet}, }; use crate::{connect_options::ConnectOptions, error::ConnectionError}; From 8b0011ded8dc338ff4fbc9f6c23b195379fd2bb5 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:02:17 +0100 Subject: [PATCH 07/45] cargo fix --- mqrstt/benches/benchmarks/mod.rs | 2 +- mqrstt/src/client.rs | 2 +- mqrstt/src/connect_options.rs | 1 - mqrstt/src/packets/connack/mod.rs | 3 +-- mqrstt/src/packets/connack/properties.rs | 2 +- mqrstt/src/packets/connect/connect_flags.rs | 6 ++---- mqrstt/src/packets/connect/connect_properties.rs | 3 +-- mqrstt/src/packets/connect/last_will.rs | 6 ++---- mqrstt/src/packets/connect/last_will_properties.rs | 5 ++--- mqrstt/src/packets/connect/mod.rs | 2 +- mqrstt/src/packets/disconnect/mod.rs | 3 +-- mqrstt/src/packets/mqtt_trait/mod.rs | 1 - mqrstt/src/packets/pubcomp/mod.rs | 1 - mqrstt/src/packets/pubcomp/properties.rs | 1 - mqrstt/src/packets/publish/mod.rs | 5 ++--- mqrstt/src/packets/publish/properties.rs | 6 ++---- mqrstt/src/packets/pubrec/mod.rs | 2 +- mqrstt/src/packets/pubrel/mod.rs | 7 +++---- mqrstt/src/packets/reason_codes.rs | 4 ---- mqrstt/src/packets/suback/mod.rs | 2 +- mqrstt/src/packets/unsuback/properties.rs | 1 - mqrstt/src/packets/unsubscribe/mod.rs | 6 +----- mqrstt/src/packets/unsubscribe/properties.rs | 4 +--- 23 files changed, 24 insertions(+), 51 deletions(-) diff --git a/mqrstt/benches/benchmarks/mod.rs b/mqrstt/benches/benchmarks/mod.rs index 3b5244f..74c1775 100644 --- a/mqrstt/benches/benchmarks/mod.rs +++ b/mqrstt/benches/benchmarks/mod.rs @@ -1,5 +1,5 @@ use bytes::{BufMut, Bytes, BytesMut}; -use mqrstt::packets::{Disconnect, Packet, Publish}; +use mqrstt::packets::{Packet, Publish}; pub mod tokio; diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index f447375..0182c2c 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -907,7 +907,7 @@ mod tests { use crate::{ error::{ClientError, PacketValidationError}, - packets::{DisconnectProperties, DisconnectReasonCode, Packet, PacketType, Publish, QoS, Subscribe, SubscribeProperties, UnsubscribeProperties}, + packets::{DisconnectProperties, DisconnectReasonCode, Packet, PacketType, Publish, QoS, SubscribeProperties, UnsubscribeProperties}, }; use super::MqttClient; diff --git a/mqrstt/src/connect_options.rs b/mqrstt/src/connect_options.rs index f83164b..224d5aa 100644 --- a/mqrstt/src/connect_options.rs +++ b/mqrstt/src/connect_options.rs @@ -1,6 +1,5 @@ use std::time::Duration; -use bytes::Bytes; use crate::util::constants::DEFAULT_RECEIVE_MAXIMUM; use crate::{ diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 1e8e1ce..a0dee84 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -7,8 +7,7 @@ pub use reason_code::ConnAckReasonCode; use super::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - PacketType, PropertyType, QoS, VariableInteger, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger, }; use bytes::{Buf, BufMut}; diff --git a/mqrstt/src/packets/connack/properties.rs b/mqrstt/src/packets/connack/properties.rs index 9b32e8e..3532f85 100644 --- a/mqrstt/src/packets/connack/properties.rs +++ b/mqrstt/src/packets/connack/properties.rs @@ -3,7 +3,7 @@ use crate::packets::{ mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, PacketType, PropertyType, QoS, VariableInteger, }; -use bytes::{Buf, BufMut}; +use bytes::BufMut; crate::packets::macros::define_properties!(ConnAckProperties, SessionExpiryInterval, diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs index b327563..08a62a8 100644 --- a/mqrstt/src/packets/connect/connect_flags.rs +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -1,11 +1,9 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Buf, BufMut}; -use crate::packets::error::ReadError; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, - PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, QoS, }; /// The connect flags describe some information related the session. diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs index d5f3fa3..f841680 100644 --- a/mqrstt/src/packets/connect/connect_properties.rs +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -1,12 +1,11 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::packets::error::ReadError; use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, - PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, + PacketType, PropertyType, WireLength, }; diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index ede6e49..9f84901 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -1,11 +1,9 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; -use crate::packets::error::ReadError; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, - PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, QoS, WireLength, }; use super::{LastWillProperties, VariableInteger}; diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs index 988a007..8935ec5 100644 --- a/mqrstt/src/packets/connect/last_will_properties.rs +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -1,12 +1,11 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; -use crate::packets::error::ReadError; use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, - PacketType, PropertyType, protocol_version::ProtocolVersion, QoS, WireLength, + PacketType, PropertyType, WireLength, }; crate::packets::macros::define_properties!(LastWillProperties, diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs index f4ba49c..0e4a0c6 100644 --- a/mqrstt/src/packets/connect/mod.rs +++ b/mqrstt/src/packets/connect/mod.rs @@ -15,7 +15,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::packets::error::ReadError; use super::{ - error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, protocol_version::ProtocolVersion, PacketType, PropertyType, QoS, VariableInteger, WireLength + error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, protocol_version::ProtocolVersion, VariableInteger, WireLength }; /// Connect packet send by the client to the server to initialize a connection. diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs index 983ca5f..bb8573b 100644 --- a/mqrstt/src/packets/disconnect/mod.rs +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -5,8 +5,7 @@ mod reason_code; pub use reason_code::DisconnectReasonCode; use super::{ - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - PacketType, PropertyType, VariableInteger, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger, }; #[derive(Debug, Default, Clone, PartialEq, Eq)] diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index 2bad43c..d90d455 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -1,5 +1,4 @@ mod primitive_impl; -pub use primitive_impl::*; use std::{future::Future}; diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 286d05d..62e2aff 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -12,7 +12,6 @@ use bytes::BufMut; use super::{ error::DeserializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - PacketType, PropertyType, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs index 3320039..3e8be02 100644 --- a/mqrstt/src/packets/pubcomp/properties.rs +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -1,5 +1,4 @@ -use bytes::BufMut; use crate::packets::{ error::DeserializeError, diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index 729532f..139c69e 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -10,8 +10,7 @@ use crate::util::constants::MAXIMUM_TOPIC_SIZE; use super::mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; use super::VariableInteger; use super::{ - error::{DeserializeError, SerializeError}, - PacketType, PropertyType, QoS, + error::{DeserializeError, SerializeError}, QoS, }; #[derive(Debug, Default, Clone, PartialEq, Eq)] @@ -134,7 +133,7 @@ mod tests { use bytes::{BufMut, BytesMut}; use crate::packets::{ - mqtt_trait::{PacketRead, PacketWrite, WireLength}, VariableInteger, + mqtt_trait::{PacketRead, PacketWrite}, VariableInteger, }; use super::Publish; diff --git a/mqrstt/src/packets/publish/properties.rs b/mqrstt/src/packets/publish/properties.rs index ecdab62..1289cd5 100644 --- a/mqrstt/src/packets/publish/properties.rs +++ b/mqrstt/src/packets/publish/properties.rs @@ -1,13 +1,11 @@ use bytes::BufMut; -use crate::error::PacketValidationError; use crate::packets::VariableInteger; -use crate::util::constants::MAXIMUM_TOPIC_SIZE; -use crate::packets::mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; use crate::packets::{ error::{DeserializeError, SerializeError}, - PacketType, PropertyType, QoS, + PacketType, PropertyType, }; crate::packets::macros::define_properties!(PublishProperties, diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 48ddbe4..835bcec 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -9,7 +9,7 @@ use bytes::BufMut; use super::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, PacketType, PropertyType, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index c320d6a..cf35314 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -9,8 +9,7 @@ use bytes::BufMut; use super::{ error::{DeserializeError, ReadError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, - PacketType, PropertyType + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength} }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -122,8 +121,8 @@ mod tests { use crate::packets::{ mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, pubrel::{PubRel, PubRelProperties}, PropertyType, PubRelReasonCode, VariableInteger }; - use bytes::{Buf, BufMut, Bytes, BytesMut}; - use tokio::{io::ReadBuf, stream}; + use bytes::{BufMut, Bytes, BytesMut}; + #[test] fn test_wire_len() { diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs index f86a8a3..842383d 100644 --- a/mqrstt/src/packets/reason_codes.rs +++ b/mqrstt/src/packets/reason_codes.rs @@ -1,9 +1,5 @@ -use std::default; -use bytes::{Buf, BufMut}; -use super::error::DeserializeError; -use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}; // #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 2702066..b4e34a3 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -8,7 +8,7 @@ use bytes::BufMut; use super::{ error::{SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, }; diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs index bd90e8e..e16ee0c 100644 --- a/mqrstt/src/packets/unsuback/properties.rs +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -1,4 +1,3 @@ -use bytes::BufMut; use crate::packets::error::{DeserializeError}; use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index f6e3bfc..e9c8d3d 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -6,11 +6,7 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use crate::packets::mqtt_trait::MqttAsyncRead; use super::VariableInteger; -use super::{ - error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, - PacketType, PropertyType, -}; +use super::mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; use bytes::BufMut; #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs index beab1d5..0487f7b 100644 --- a/mqrstt/src/packets/unsubscribe/properties.rs +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -1,12 +1,10 @@ use crate::packets::VariableInteger; -use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, PacketType, PropertyType, }; -use bytes::BufMut; crate::packets::macros::define_properties!(UnsubscribeProperties, UserProperty); From a2cb7498490cb62346ac6399196e797d8c5333b4 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:05:31 +0100 Subject: [PATCH 08/45] removed unused imports and make items private --- mqrstt/src/client.rs | 2 +- mqrstt/src/packets/connack/mod.rs | 2 +- mqrstt/src/packets/connack/properties.rs | 2 +- mqrstt/src/packets/connect/connect_properties.rs | 2 +- mqrstt/src/packets/connect/last_will.rs | 2 +- mqrstt/src/packets/connect/last_will_properties.rs | 2 +- mqrstt/src/packets/mod.rs | 2 +- mqrstt/src/packets/pubcomp/properties.rs | 2 +- mqrstt/src/packets/publish/properties.rs | 2 +- mqrstt/src/packets/unsubscribe/properties.rs | 2 +- mqrstt/src/smol/network.rs | 2 +- mqrstt/src/tokio/network.rs | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 0182c2c..579a3f7 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -900,7 +900,7 @@ impl MqttClient { } } -#[cfg(any(feature = "tokio", feature = "smol", feature = "quic"))] +#[cfg(any(feature = "tokio", feature = "smol"))] #[cfg(test)] mod tests { use async_channel::Receiver; diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index a0dee84..56d0259 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -44,7 +44,7 @@ impl PacketRead for ConnAck { } impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncReadExt + Unpin { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + fn async_read(_: u8, _: usize, stream: &mut S) -> impl std::future::Future> { async move { let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; diff --git a/mqrstt/src/packets/connack/properties.rs b/mqrstt/src/packets/connack/properties.rs index 3532f85..91b1cf4 100644 --- a/mqrstt/src/packets/connack/properties.rs +++ b/mqrstt/src/packets/connack/properties.rs @@ -1,6 +1,6 @@ use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, QoS, VariableInteger, }; use bytes::BufMut; diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs index f841680..ca25dc5 100644 --- a/mqrstt/src/packets/connect/connect_properties.rs +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -4,7 +4,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + mqtt_trait::{MqttRead, MqttWrite}, PacketType, PropertyType, WireLength, }; diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index 9f84901..b195b32 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -3,7 +3,7 @@ use bytes::{Bytes, BytesMut}; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, QoS, WireLength, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, QoS, WireLength, }; use super::{LastWillProperties, VariableInteger}; diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs index 8935ec5..8051453 100644 --- a/mqrstt/src/packets/connect/last_will_properties.rs +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -4,7 +4,7 @@ use bytes::{BufMut, Bytes, BytesMut}; use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + mqtt_trait::{MqttRead, MqttWrite}, PacketType, PropertyType, WireLength, }; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 78c86e6..61f4fe2 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -145,7 +145,7 @@ impl Packet { } Packet::UnsubAck(_) => { unreachable!(); - buf.put_u8(0b1011_0000); + // buf.put_u8(0b1011_0000); } Packet::PingReq => { buf.put_u8(0b1100_0000); diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs index 3e8be02..3cbb8cf 100644 --- a/mqrstt/src/packets/pubcomp/properties.rs +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -2,7 +2,7 @@ use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, }; diff --git a/mqrstt/src/packets/publish/properties.rs b/mqrstt/src/packets/publish/properties.rs index 1289cd5..9d7790e 100644 --- a/mqrstt/src/packets/publish/properties.rs +++ b/mqrstt/src/packets/publish/properties.rs @@ -2,7 +2,7 @@ use bytes::BufMut; use crate::packets::VariableInteger; -use crate::packets::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}; +use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; use crate::packets::{ error::{DeserializeError, SerializeError}, PacketType, PropertyType, diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs index 0487f7b..41c8091 100644 --- a/mqrstt/src/packets/unsubscribe/properties.rs +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -2,7 +2,7 @@ use crate::packets::VariableInteger; use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, }; diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 29ab074..6497dcd 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -38,7 +38,7 @@ pub struct Network { } impl Network { - pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { + pub(crate) fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { let state_handler = StateHandler::new(&options, apkids); Self { handler: PhantomData, diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index f55395d..d526988 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -36,7 +36,7 @@ pub struct Network { } impl Network { - pub fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { + pub(crate) fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { Self { handler_helper: PhantomData, handler: PhantomData, From 1e950748b8df3697c3aa0e583c5958a06fae21ae Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Wed, 20 Nov 2024 23:37:27 +0100 Subject: [PATCH 09/45] add more packet async read and fix some small stuff --- .github/workflows/rust.yml | 11 +- mqrstt/Cargo.toml | 2 +- mqrstt/benches/benchmarks/mod.rs | 2 +- mqrstt/src/client.rs | 15 +- mqrstt/src/packets/error.rs | 8 + .../src/packets/macros/properties_macros.rs | 544 +++++++++--------- mqrstt/src/packets/mod.rs | 238 +++----- mqrstt/src/packets/mqtt_trait/mod.rs | 4 + mqrstt/src/packets/primitive/fixed_header.rs | 22 +- mqrstt/src/packets/primitive/property_type.rs | 15 +- .../src/packets/primitive/variable_integer.rs | 71 ++- mqrstt/src/packets/pubcomp/mod.rs | 6 +- mqrstt/src/packets/publish/mod.rs | 59 +- mqrstt/src/packets/pubrec/mod.rs | 33 +- mqrstt/src/packets/pubrel/mod.rs | 32 +- mqrstt/src/packets/reason_codes.rs | 254 -------- mqrstt/src/packets/suback/mod.rs | 6 +- .../{subscribe.rs => subscribe/mod.rs} | 140 ++--- mqrstt/src/packets/subscribe/properties.rs | 94 +++ mqrstt/src/packets/subscribe/reason_code.rs | 0 mqrstt/src/packets/unsuback/mod.rs | 43 +- mqrstt/src/packets/unsubscribe/mod.rs | 4 +- mqrstt/src/tests/test_packets.rs | 141 ++++- 23 files changed, 863 insertions(+), 881 deletions(-) delete mode 100644 mqrstt/src/packets/reason_codes.rs rename mqrstt/src/packets/{subscribe.rs => subscribe/mod.rs} (79%) delete mode 100644 mqrstt/src/packets/subscribe/reason_code.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index da47ec8..487fc0f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -1,10 +1,17 @@ name: Rust + +# run on push and pull request to main and release branches on: push: - branches: [ "main" ] + branches: + - main + - release/* pull_request: - branches: [ "main" ] + branches: + - main + - release/* + env: CARGO_TERM_COLOR: always diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index b433692..54c19f6 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -73,7 +73,7 @@ rustls-pemfile = { version = "1.0.3" } webpki = { version = "0.22.4" } async-rustls = { version = "0.4.1" } tokio-rustls = "0.24.1" -rstest = "0.18.2" +rstest = "0.23.0" rand = "0.8.5" diff --git a/mqrstt/benches/benchmarks/mod.rs b/mqrstt/benches/benchmarks/mod.rs index 74c1775..18fcb67 100644 --- a/mqrstt/benches/benchmarks/mod.rs +++ b/mqrstt/benches/benchmarks/mod.rs @@ -66,7 +66,7 @@ fn very_large_publish(id: u16, repeat: usize) -> Packet { topic: "BlaBla".into(), packet_identifier: Some(id), publish_properties: Default::default(), - payload: Bytes::from_iter("ping".repeat(repeat).into_bytes()), + payload: b"ping".repeat(repeat).to_vec(), }; Packet::Publish(publ) diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 579a3f7..4123d73 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -1,5 +1,4 @@ use async_channel::{Receiver, Sender}; -use bytes::Bytes; #[cfg(feature = "logs")] use tracing::info; @@ -158,7 +157,7 @@ impl MqttClient { /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; /// /// let sub_properties = SubscribeProperties{ - /// subscription_id: Some(1), + /// subscription_identifier: Some(1), /// user_properties: vec![], /// }; /// @@ -236,7 +235,7 @@ impl MqttClient { /// /// # }); /// ``` - pub async fn publish, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { + pub async fn publish, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?), @@ -308,7 +307,7 @@ impl MqttClient { /// # }); /// # let _network = std::hint::black_box(network); /// ``` - pub async fn publish_with_properties, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { + pub async fn publish_with_properties, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?), @@ -563,7 +562,7 @@ impl MqttClient { /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; /// /// let sub_properties = SubscribeProperties{ - /// subscription_id: Some(1), + /// subscription_identifier: Some(1), /// user_properties: vec![], /// }; /// # let sub_properties_clone = sub_properties.clone(); @@ -642,7 +641,7 @@ impl MqttClient { /// /// # }); /// ``` - pub fn publish_blocking, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { + pub fn publish_blocking, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?), @@ -715,7 +714,7 @@ impl MqttClient { /// /// # }); /// ``` - pub fn publish_with_properties_blocking, P: Into>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { + pub fn publish_with_properties_blocking, P: Into>>(&self, topic: T, qos: QoS, retain: bool, payload: P, properties: PublishProperties) -> Result<(), ClientError> { let pkid = match qos { QoS::AtMostOnce => None, _ => Some(self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?), @@ -957,7 +956,7 @@ mod tests { let (mqtt_client, client_to_handler_r, to_network_r) = create_new_test_client(); let sub_properties = SubscribeProperties{ - subscription_id: Some(1), + subscription_identifier: Some(1), user_properties: vec![], }; diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index cb423ed..7585b3a 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -4,6 +4,14 @@ use thiserror::Error; use super::{PacketType, PropertyType}; +#[derive(Error, Debug)] +pub enum WriteError { + #[error("{0}")] + SerializeError(#[from] SerializeError), + #[error("{0}")] + IoError(#[from] std::io::Error), +} + #[derive(Error, Debug)] pub enum ReadError{ #[error("{0}")] diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index ef1ed6c..d6a01c1 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -17,7 +17,7 @@ macro_rules! define_properties { read_property_bytes += read_bytes; match prop { $( - PropertyType::$prop_variant => $crate::packets::macros::properties_read_matches!(stream, properties, read_property_bytes, PropertyType::$prop_variant), + $crate::packets::macros::properties_read_match_branch_name!($prop_variant) => $crate::packets::macros::properties_read_match_branch_body!(stream, properties, read_property_bytes, PropertyType::$prop_variant), )* e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), } @@ -279,286 +279,267 @@ macro_rules! properties_struct { ); } -macro_rules! properties_read_matches { +macro_rules! properties_read_match_branch_body { + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::PayloadFormatIndicator) => {{ + if $properties.payload_format_indicator.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::PayloadFormatIndicator, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.payload_format_indicator = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MessageExpiryInterval) => {{ + if $properties.message_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::MessageExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.message_expiry_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ContentType) => {{ + if $properties.content_type.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ContentType))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.content_type = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseTopic) => {{ + if $properties.response_topic.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_topic = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::CorrelationData) => {{ + if $properties.correlation_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::CorrelationData))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.correlation_data = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => {{ + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifier = Some(prop_body as u32); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ListSubscriptionIdentifier) => {{ + let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_identifiers.push(prop_body as u32); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => {{ + if $properties.session_expiry_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SessionExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.session_expiry_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => {{ + if $properties.assigned_client_id.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AssignedClientIdentifier, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.assigned_client_id = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => {{ + if $properties.server_keep_alive.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_keep_alive = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => {{ + if $properties.authentication_method.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AuthenticationMethod, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_method = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => {{ + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::AuthenticationData, + ))); + } + let (prop_body, read_bytes) = Vec::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.authentication_data = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => {{ + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::RequestResponseInformation, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_problem_information = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestProblemInformation) => {{ + if $properties.authentication_data.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::RequestProblemInformation, + ))); + } + let (prop_body, read_bytes) = u8::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.request_problem_information = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WillDelayInterval) => {{ + if $properties.will_delay_interval.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::WillDelayInterval, + ))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.will_delay_interval = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => {{ + if $properties.response_info.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::ResponseInformation, + ))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.response_info = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => {{ + if $properties.server_reference.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.server_reference = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReasonString) => {{ + if $properties.reason_string.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); + } + let (prop_body, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.reason_string = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => {{ + if $properties.receive_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.receive_maximum = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => {{ + if $properties.topic_alias_maximum.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::TopicAliasMaximum, + ))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.topic_alias_maximum = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAlias) => {{ + if $properties.topic_alias.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::MessageExpiryInterval, + ))); + } + let (prop_body, read_bytes) = u16::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.topic_alias = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => {{ + if $properties.maximum_qos.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); + } + let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_qos = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => {{ + if $properties.retain_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.retain_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => {{ + let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; + let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; + $read_property_bytes += read_bytes; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::PayloadFormatIndicator) => { - { - if $properties.payload_format_indicator.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::PayloadFormatIndicator))); - } - let (prop_body, read_bytes) = u8::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.payload_format_indicator = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MessageExpiryInterval) => { - { - if $properties.message_expiry_interval.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval))); - } - let (prop_body, read_bytes) = u32::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.message_expiry_interval = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ContentType) => { - { - if $properties.content_type.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ContentType))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.content_type = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseTopic) => { - { - if $properties.response_topic.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseTopic))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.response_topic = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::CorrelationData) => { - { - if $properties.correlation_data.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::CorrelationData))); - } - let (prop_body, read_bytes) = Vec::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.correlation_data = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifier) => { - { - let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; - $read_property_bytes += read_bytes; - $properties.subscription_identifier = Some(prop_body as u32); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ListSubscriptionIdentifier) => { - { - let (prop_body, read_bytes) = ::read_async_variable_integer($stream).await?; - $read_property_bytes += read_bytes; - $properties.subscription_identifiers.push(prop_body as u32); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SessionExpiryInterval) => { - { - if $properties.session_expiry_interval.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval))); - } - let (prop_body, read_bytes) = u32::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.session_expiry_interval = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => { - { - if $properties.assigned_client_id.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.assigned_client_id = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => { - { - if $properties.server_keep_alive.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerKeepAlive))); - } - let (prop_body, read_bytes) = u16::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.server_keep_alive = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationMethod) => { - { - if $properties.authentication_method.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.authentication_method = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AuthenticationData) => { - { - if $properties.authentication_data.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData))); - } - let (prop_body, read_bytes) = Vec::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.authentication_data = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => { - { - if $properties.authentication_data.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestResponseInformation))); - } - let (prop_body, read_bytes) = u8::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.request_problem_information = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestProblemInformation) => { - { - if $properties.authentication_data.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RequestProblemInformation))); - } - let (prop_body, read_bytes) = u8::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.request_problem_information = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WillDelayInterval) => { - { - if $properties.will_delay_interval.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WillDelayInterval))); - } - let (prop_body, read_bytes) = u32::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.will_delay_interval = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ResponseInformation) => { - { - if $properties.response_info.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ResponseInformation))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.response_info = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerReference) => { - { - if $properties.server_reference.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ServerReference))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.server_reference = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReasonString) => { - { - if $properties.reason_string.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); - } - let (prop_body, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.reason_string = Some(prop_body); - } + $properties.user_properties.push((prop_body_key, prop_body_value)) + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => {{ + if $properties.maximum_packet_size.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); + } + let (prop_body, read_bytes) = u32::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.maximum_packet_size = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => {{ + if $properties.wildcards_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::WildcardSubscriptionAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.wildcards_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => {{ + if $properties.subscription_ids_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SubscriptionIdentifierAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.subscription_ids_available = Some(prop_body); + }}; + ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => {{ + if $properties.shared_subscription_available.is_some() { + return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( + PropertyType::SharedSubscriptionAvailable, + ))); + } + let (prop_body, read_bytes) = bool::async_read($stream).await?; + $read_property_bytes += read_bytes; + $properties.shared_subscription_available = Some(prop_body); + }}; +} + +macro_rules! properties_read_match_branch_name { + (ListSubscriptionIdentifier) => { + PropertyType::SubscriptionIdentifier }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ReceiveMaximum) => { - { - if $properties.receive_maximum.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReceiveMaximum))); - } - let (prop_body, read_bytes) = u16::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.receive_maximum = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAliasMaximum) => { - { - if $properties.topic_alias_maximum.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::TopicAliasMaximum))); - } - let (prop_body, read_bytes) = u16::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.topic_alias_maximum = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::TopicAlias) => { - { - if $properties.topic_alias.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MessageExpiryInterval))); - } - let (prop_body, read_bytes) = u16::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.topic_alias = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumQos) => { - { - if $properties.maximum_qos.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::MaximumQos))); - } - let (prop_body, read_bytes) = $crate::packets::QoS::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.maximum_qos = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RetainAvailable) => { - { - if $properties.retain_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); - } - let (prop_body, read_bytes) = bool::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.retain_available = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::UserProperty) => { - { - let (prop_body_key, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - let (prop_body_value, read_bytes) = Box::::async_read($stream).await?; - $read_property_bytes += read_bytes; - - $properties.user_properties.push((prop_body_key, prop_body_value)) - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::MaximumPacketSize) => { - { - if $properties.maximum_packet_size.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::RetainAvailable))); - } - let (prop_body, read_bytes) = u32::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.maximum_packet_size = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::WildcardSubscriptionAvailable) => { - { - if $properties.wildcards_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::WildcardSubscriptionAvailable))); - } - let (prop_body, read_bytes) = bool::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.wildcards_available = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SubscriptionIdentifierAvailable) => { - { - if $properties.subscription_ids_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifierAvailable))); - } - let (prop_body, read_bytes) = bool::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.subscription_ids_available = Some(prop_body); - } - }; - ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::SharedSubscriptionAvailable) => { - { - if $properties.shared_subscription_available.is_some() { - return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::SharedSubscriptionAvailable))); - } - let (prop_body, read_bytes) = bool::async_read($stream).await?; - $read_property_bytes += read_bytes; - $properties.shared_subscription_available = Some(prop_body); - } + ($name:ident) => { + PropertyType::$name }; } -macro_rules! properties_wire_length{ - +macro_rules! properties_wire_length { ($self:ident, $len:ident, PropertyType::PayloadFormatIndicator) => { if $self.payload_format_indicator.is_some() { $len += 2; @@ -590,7 +571,7 @@ macro_rules! properties_wire_length{ } }; ($self:ident, $len:ident, PropertyType::ListSubscriptionIdentifier) => { - for sub_id in &($self.subscription_identifiers) { + for sub_id in &($self.subscription_identifiers) { $len += 1 + crate::packets::primitive::VariableInteger::variable_integer_len(sub_id); } }; @@ -703,12 +684,13 @@ macro_rules! properties_wire_length{ $len += 1 + 1; } }; - ($self:ident, $len:ident, $unknown:ident) => ( + ($self:ident, $len:ident, $unknown:ident) => { compile_error!(concat!("Unknown property: ", stringify!($unknown))); - ); + }; } pub(crate) use define_properties; +pub(crate) use properties_read_match_branch_body; +pub(crate) use properties_read_match_branch_name; pub(crate) use properties_struct; -pub(crate) use properties_read_matches; -pub(crate) use properties_wire_length; \ No newline at end of file +pub(crate) use properties_wire_length; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 61f4fe2..349e1b0 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -1,6 +1,5 @@ pub mod error; pub mod mqtt_trait; -pub mod reason_codes; mod macros; @@ -19,6 +18,8 @@ mod unsuback; mod unsubscribe; mod primitive; +use error::ReadError; +use mqtt_trait::PacketAsyncRead; pub use primitive::*; pub use auth::*; @@ -190,6 +191,41 @@ impl Packet { Ok(packet) } + pub(crate) async fn async_read(header: FixedHeader, stream: &mut S) -> Result + where + S: tokio::io::AsyncRead + Unpin, + { + let packet = match header.packet_type { + PacketType::Connect => Packet::Connect(Connect::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::ConnAck => Packet::ConnAck(ConnAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Publish => Packet::Publish(Publish::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubAck => Packet::PubAck(PubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubRec => Packet::PubRec(PubRec::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubRel => Packet::PubRel(PubRel::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PubComp => Packet::PubComp(PubComp::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Subscribe => Packet::Subscribe(Subscribe::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::SubAck => Packet::SubAck(SubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Unsubscribe => Packet::Unsubscribe(Unsubscribe::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::UnsubAck => Packet::UnsubAck(UnsubAck::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::PingReq => Packet::PingReq, + PacketType::PingResp => Packet::PingResp, + PacketType::Disconnect => Packet::Disconnect(Disconnect::async_read(header.flags, header.remaining_length, stream).await?.0), + PacketType::Auth => Packet::Auth(Auth::async_read(header.flags, header.remaining_length, stream).await?.0), + }; + Ok(packet) + } + + #[cfg(test)] + pub(crate) async fn async_read_from_buffer(stream: &mut S) -> Result + where + S: tokio::io::AsyncRead + Unpin, + { + let (header, _) = FixedHeader::async_read(stream).await?; + + Ok(Packet::async_read(header, stream).await?) + } + + #[cfg(test)] pub(crate) fn read_from_buffer(buffer: &mut BytesMut) -> Result> { let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; if header.remaining_length + header_length > buffer.len() { @@ -253,6 +289,7 @@ pub enum PacketType { Auth, } impl PacketType { + #[inline] const fn from_first_byte(value: u8) -> Result<(Self, u8), DeserializeError> { match (value >> 4, value & 0x0f) { (0b0001, 0) => Ok((PacketType::Connect, 0)), @@ -283,155 +320,29 @@ impl std::fmt::Display for PacketType { #[cfg(test)] mod tests { - use bytes::{Bytes, BytesMut}; + use bytes::BytesMut; use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties}; use crate::packets::disconnect::{Disconnect, DisconnectProperties}; use crate::packets::QoS; - use crate::packets::publish::{Publish, PublishProperties}; - use crate::packets::pubrel::{PubRel, PubRelProperties}; use crate::packets::connack::ConnAckReasonCode; use crate::packets::disconnect::DisconnectReasonCode; + use crate::packets::publish::{Publish, PublishProperties}; use crate::packets::pubrel::PubRelReasonCode; + use crate::packets::pubrel::{PubRel, PubRelProperties}; use crate::packets::Packet; - #[test] - fn test_connack_read() { - let connack = [ - 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, - ]; - let mut buf = BytesMut::new(); - buf.extend(connack); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = ConnAck { - connack_flags: ConnAckFlags { session_present: true }, - reason_code: ConnAckReasonCode::Success, - connack_properties: ConnAckProperties { - session_expiry_interval: None, - receive_maximum: None, - maximum_qos: None, - retain_available: Some(true), - maximum_packet_size: Some(1048576), - assigned_client_id: None, - topic_alias_maximum: Some(65535), - reason_string: None, - user_properties: vec![], - wildcards_available: Some(true), - subscription_ids_available: Some(true), - shared_subscription_available: Some(true), - server_keep_alive: None, - response_info: None, - server_reference: None, - authentication_method: None, - authentication_data: None, - }, - }; - - assert_eq!(Packet::ConnAck(expected), res); - } - - #[test] - fn test_disconnect_read() { - let packet = [0xe0, 0x02, 0x8e, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Disconnect { - reason_code: DisconnectReasonCode::SessionTakenOver, - properties: DisconnectProperties { - session_expiry_interval: None, - reason_string: None, - user_properties: vec![], - server_reference: None, - }, - }; - - assert_eq!(Packet::Disconnect(expected), res); - } - - #[test] - fn test_pingreq_read_write() { - let packet = [0xc0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingReq, res); - - buf.clear(); - Packet::PingReq.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_pingresp_read_write() { - let packet = [0xd0, 0x00]; - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - assert_eq!(Packet::PingResp, res); - - buf.clear(); - Packet::PingResp.write(&mut buf).unwrap(); - assert_eq!(buf.to_vec(), packet); - } - - #[test] - fn test_publish_read() { - let packet = [ - 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, 0x01, 0x09, 0x00, - 0x04, 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, - ]; - - let mut buf = BytesMut::new(); - buf.extend(packet); - - let res = Packet::read_from_buffer(&mut buf); - assert!(res.is_ok()); - let res = res.unwrap(); - - let expected = Publish { - dup: false, - qos: QoS::ExactlyOnce, - retain: true, - topic: "test/123/test/blabla".into(), - packet_identifier: Some(13779), - publish_properties: PublishProperties { - payload_format_indicator: Some(1), - message_expiry_interval: None, - topic_alias: None, - response_topic: None, - correlation_data: Some(b"1212".to_vec()), - subscription_identifiers: vec![1], - user_properties: vec![], - content_type: None, - }, - payload: Bytes::from_static(b""), - }; - - assert_eq!(Packet::Publish(expected), res); - } - - #[test] - fn test_pubrel_read_write() { - let bytes = [0x62, 0x03, 0x35, 0xd3, 0x00]; + use crate::tests::test_packets::{disconnect_case, ping_req_case, ping_resp_case, publish_case, pubrel_case, pubrel_smallest_case}; + #[rstest::rstest] + #[case(disconnect_case())] + #[case(ping_req_case())] + #[case(ping_resp_case())] + #[case(publish_case())] + #[case(pubrel_case())] + #[case(pubrel_smallest_case())] + fn test_read_write_cases(#[case] (bytes, expected_packet): (&[u8], Packet)) { let mut buffer = BytesMut::from_iter(bytes); let res = Packet::read_from_buffer(&mut buffer); @@ -440,52 +351,37 @@ mod tests { let packet = res.unwrap(); - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); + assert_eq!(packet, expected_packet); buffer.clear(); packet.write(&mut buffer).unwrap(); - // The input is not in the smallest possible format but when writing we do expect it to be in the smallest possible format. - assert_eq!(buffer.to_vec(), [0x62, 0x02, 0x35, 0xd3].to_vec()) + assert_eq!(buffer.to_vec(), bytes.to_vec()) } - #[test] - fn test_pubrel_read_smallest_format() { - let bytes = [0x62, 0x02, 0x35, 0xd3]; - - let mut buffer = BytesMut::from_iter(bytes); + #[rstest::rstest] + #[case(disconnect_case())] + #[case(ping_req_case())] + #[case(ping_resp_case())] + #[case(publish_case())] + #[case(pubrel_case())] + #[case(pubrel_smallest_case())] + #[tokio::test] + async fn test_async_read_write(#[case] (mut bytes, expected_packet): (&[u8], Packet)) { + // let mut buffer = BytesMut::from(bytes); - let res = Packet::read_from_buffer(&mut buffer); + let res = Packet::async_read_from_buffer(&mut bytes).await; + dbg!(&res); assert!(res.is_ok()); let packet = res.unwrap(); - let expected = PubRel { - packet_identifier: 13779, - reason_code: PubRelReasonCode::Success, - properties: PubRelProperties { - reason_string: None, - user_properties: vec![], - }, - }; - - assert_eq!(packet, Packet::PubRel(expected)); + assert_eq!(packet, expected_packet); - buffer.clear(); + // packet.write(&mut buffer).unwrap(); - packet.write(&mut buffer).unwrap(); - - assert_eq!(buffer.to_vec(), bytes.to_vec()) + // assert_eq!() } } diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index d90d455..0bd95be 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -68,6 +68,10 @@ pub(crate) trait PacketAsyncRead: Sized where S: tokio::io::AsyncReadExt + U fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; } +pub(crate) trait PacketAsyncWrite: Sized where S: tokio::io::AsyncWriteExt + Unpin { + fn async_write(&self, stream: &mut S) -> impl Future>; +} + pub(crate) trait PacketWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; } diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs index 709c565..cb49a93 100644 --- a/mqrstt/src/packets/primitive/fixed_header.rs +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -1,5 +1,11 @@ use core::slice::Iter; -use crate::packets::{error::{DeserializeError, ReadBytes}, PacketType}; + +use tokio::io::AsyncReadExt; + +use crate::packets::{ + error::{DeserializeError, ReadBytes}, + PacketType, +}; use super::read_fixed_header_rem_len; @@ -37,4 +43,16 @@ impl FixedHeader { Ok((Self { packet_type, flags, remaining_length }, header_length)) } -} \ No newline at end of file + + pub async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> + where + S: tokio::io::AsyncRead + Unpin, + { + let first_byte = stream.read_u8().await?; + + let (packet_type, flags) = PacketType::from_first_byte(first_byte)?; + + let (remaining_length, length) = super::async_read_fixed_header_rem_len(stream).await?; + Ok((Self { packet_type, flags, remaining_length }, 1 + length)) + } +} diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs index 1091293..54bd00f 100644 --- a/mqrstt/src/packets/primitive/property_type.rs +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -1,6 +1,9 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum PropertyType { @@ -10,8 +13,6 @@ pub enum PropertyType { ResponseTopic = 8, CorrelationData = 9, SubscriptionIdentifier = 11, - /// Alternative to SubscriptionIdentifier, dummy property to encode multiple subscription identifiers - ListSubscriptionIdentifier, SessionExpiryInterval = 17, AssignedClientIdentifier = 18, ServerKeepAlive = 19, @@ -81,7 +82,6 @@ impl From<&PropertyType> for u8 { PropertyType::ResponseTopic => 8, PropertyType::CorrelationData => 9, PropertyType::SubscriptionIdentifier => 11, - PropertyType::ListSubscriptionIdentifier => 11, PropertyType::SessionExpiryInterval => 17, PropertyType::AssignedClientIdentifier => 18, PropertyType::ServerKeepAlive => 19, @@ -123,7 +123,10 @@ impl MqttRead for PropertyType { } } -impl MqttAsyncRead for PropertyType where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for PropertyType +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await { Ok(t) => Ok((t.try_into()?, 1)), @@ -137,4 +140,4 @@ impl MqttWrite for PropertyType { buf.put_u8(self.into()); Ok(()) } -} \ No newline at end of file +} diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs index c27bf8c..673fe36 100644 --- a/mqrstt/src/packets/primitive/variable_integer.rs +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -1,7 +1,9 @@ +use crate::packets::error::{DeserializeError, ReadBytes, ReadError, SerializeError}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use core::slice::Iter; use std::future::Future; -use crate::packets::error::{DeserializeError, ReadBytes, ReadError, SerializeError}; + +use tokio::io::AsyncReadExt; pub(crate) fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { let mut integer = 0; @@ -22,18 +24,31 @@ pub(crate) fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usi Err(ReadBytes::Err(DeserializeError::MalformedPacket)) } +pub(crate) async fn async_read_fixed_header_rem_len(stream: &mut S) -> Result<(usize, usize), ReadError> +where + S: tokio::io::AsyncRead + Unpin, +{ + let mut integer = 0; + let mut length = 0; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) +} pub(crate) trait VariableInteger: Sized { fn variable_integer_len(&self) -> usize; fn write_variable_integer(&self, buf: &mut BytesMut) -> Result; fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError>; fn read_async_variable_integer(stream: &mut S) -> impl Future>; - } - - impl VariableInteger for usize { - fn variable_integer_len(&self) -> usize { if *self >= 2_097_152 { 4 @@ -45,14 +60,14 @@ impl VariableInteger for usize { 1 } } - + fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { if *self > 268_435_455 { return Err(SerializeError::VariableIntegerOverflow(*self as usize)); } - + let mut write = *self; - + for i in 0..4 { let mut byte = (write % 128) as u8; write /= 128; @@ -61,43 +76,43 @@ impl VariableInteger for usize { } buf.put_u8(byte); if write == 0 { - return Ok(i+1); + return Ok(i + 1); } } Err(SerializeError::VariableIntegerOverflow(*self as usize)) } - + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { let mut integer = 0; let mut length = 0; - + for i in 0..4 { if buf.is_empty() { return Err(DeserializeError::MalformedPacket); } length += 1; let byte = buf.get_u8(); - + integer += (byte as usize & 0x7f) << (7 * i); - + if (byte & 0b1000_0000) == 0 { return Ok((integer, length)); } } Err(DeserializeError::MalformedPacket) } - + fn read_async_variable_integer(stream: &mut S) -> impl Future> { async move { let mut integer = 0; let mut length = 0; - + for i in 0..4 { let byte = stream.read_u8().await?; length += 1; - + integer += (byte as usize & 0x7f) << (7 * i); - + if (byte & 0b1000_0000) == 0 { return Ok((integer, length)); } @@ -124,9 +139,9 @@ impl VariableInteger for u32 { if *self > 268_435_455 { return Err(SerializeError::VariableIntegerOverflow(*self as usize)); } - + let mut write = *self; - + for i in 0..4 { let mut byte = (write % 128) as u8; write /= 128; @@ -135,43 +150,43 @@ impl VariableInteger for u32 { } buf.put_u8(byte); if write == 0 { - return Ok(i+1); + return Ok(i + 1); } } Err(SerializeError::VariableIntegerOverflow(*self as usize)) } - + fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { let mut integer = 0; let mut length = 0; - + for i in 0..4 { if buf.is_empty() { return Err(DeserializeError::MalformedPacket); } length += 1; let byte = buf.get_u8(); - + integer += (byte as u32 & 0x7f) << (7 * i); - + if (byte & 0b1000_0000) == 0 { return Ok((integer, length)); } } Err(DeserializeError::MalformedPacket) } - + fn read_async_variable_integer(stream: &mut S) -> impl Future> { async move { let mut integer = 0; let mut length = 0; - + for i in 0..4 { let byte = stream.read_u8().await?; length += 1; - + integer += (byte as u32 & 0x7f) << (7 * i); - + if (byte & 0b1000_0000) == 0 { return Ok((integer, length)); } diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 62e2aff..2c4f07f 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -61,7 +61,7 @@ impl PacketRead for PubComp { impl PacketAsyncRead for PubComp where S: tokio::io::AsyncReadExt + Unpin { fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { - let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + let packet_identifier = stream.read_u16().await?; if remaining_length == 2 { return Ok((Self { packet_identifier, @@ -77,13 +77,13 @@ impl PacketAsyncRead for PubComp where S: tokio::io::AsyncReadExt + Unpin let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; - assert_eq!(id_read_bytes + reason_code_read_bytes + properties_read_bytes, remaining_length); + assert_eq!(2 + reason_code_read_bytes + properties_read_bytes, remaining_length); Ok((Self { packet_identifier, reason_code, properties, - }, id_read_bytes + reason_code_read_bytes + properties_read_bytes)) + }, 2 + reason_code_read_bytes + properties_read_bytes)) } } } diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index 139c69e..c63063e 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -1,13 +1,16 @@ mod properties; pub use properties::PublishProperties; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncBufRead; + use bytes::{BufMut, Bytes}; use crate::error::PacketValidationError; use crate::util::constants::MAXIMUM_TOPIC_SIZE; -use super::mqtt_trait::{MqttRead, MqttWrite, PacketValidation, PacketRead, PacketWrite, WireLength}; +use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; use super::VariableInteger; use super::{ error::{DeserializeError, SerializeError}, QoS, @@ -34,11 +37,11 @@ pub struct Publish { pub publish_properties: PublishProperties, /// 3.3.3 PUBLISH Payload - pub payload: Bytes, + pub payload: Vec, } impl Publish { - pub fn new>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: Bytes) -> Self { + pub fn new, P: Into>>(qos: QoS, retain: bool, topic: S, packet_identifier: Option, publish_properties: PublishProperties, payload: P) -> Self { Self { dup: false, qos, @@ -46,12 +49,12 @@ impl Publish { topic: topic.as_ref().into(), packet_identifier, publish_properties, - payload, + payload: payload.into(), } } - pub fn payload_to_vec(&self) -> Vec { - self.payload.to_vec() + pub fn payload(&self) -> &Vec { + &self.payload } } @@ -76,11 +79,53 @@ impl PacketRead for Publish { topic, packet_identifier, publish_properties, - payload: buf, + payload: buf.to_vec(), }) } } +impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin { + fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_bytes; + let packet_identifier = if qos == QoS::AtMostOnce { None } else { + total_read_bytes += 2; + Some(stream.read_u16().await?) + }; + let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + let payload_len = remaining_length - total_read_bytes; + let mut payload = vec![0u8; payload_len]; + let payload_read_bytes = stream.read_exact(&mut payload).await?; + + assert_eq!(payload_read_bytes, payload_len); + + + Ok( + ( + Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload, + }, + total_read_bytes + payload_read_bytes + ) + ) + } + } +} + impl PacketWrite for Publish { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { self.topic.write(buf)?; diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 835bcec..8edadf4 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -7,9 +7,11 @@ pub use reason_code::PubRecReasonCode; use bytes::BufMut; +use tokio::io::AsyncReadExt; + use super::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, PacketAsyncRead, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -55,6 +57,35 @@ impl PacketRead for PubRec { } } +impl PacketAsyncRead for PubRec where S: tokio::io::AsyncRead + Unpin{ + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + if remaining_length == 2 { + return Ok((Self { + packet_identifier, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), + }, total_read_bytes)); + } + + let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; + + total_read_bytes += reason_code_read_bytes + properties_read_bytes; + + Ok((Self { + packet_identifier, + properties, + reason_code, + }, total_read_bytes)) + + } + } +} + impl PacketWrite for PubRec { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index cf35314..974b026 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -1,4 +1,3 @@ - mod reason_code; pub use reason_code::PubRelReasonCode; @@ -8,8 +7,8 @@ pub use properties::PubRelProperties; use bytes::BufMut; use super::{ - error::{DeserializeError, ReadError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength} + error::{DeserializeError, ReadError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, }; #[derive(Debug, PartialEq, Eq, Clone, Hash)] @@ -54,11 +53,14 @@ impl PacketRead for PubRel { } } -impl PacketAsyncRead for PubRel where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for PubRel +where + S: tokio::io::AsyncReadExt + Unpin, +{ async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), ReadError> { let mut total_read_bytes = 0; - let (packet_identifier, read_bytes) = u16::async_read(stream).await?; - total_read_bytes += read_bytes; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; let res = if remaining_length == 2 { Self { packet_identifier, @@ -119,10 +121,11 @@ impl WireLength for PubRel { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, pubrel::{PubRel, PubRelProperties}, PropertyType, PubRelReasonCode, VariableInteger + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + pubrel::{PubRel, PubRelProperties}, + PropertyType, PubRelReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; - #[test] fn test_wire_len() { @@ -147,16 +150,15 @@ mod tests { assert_eq!(3, buf.len()); } - #[test] fn test_wire_len2() { let mut buf = BytesMut::new(); - let prop = PubRelProperties{ + let prop = PubRelProperties { reason_string: Some("reason string, test 1-2-3.".into()), // 26 + 1 + 2 user_properties: vec![ ("This is the key".into(), "This is the value".into()), // 32 + 1 + 2 + 2 - ("Another thingy".into(), "The thingy".into()), // 24 + 1 + 2 + 2 + ("Another thingy".into(), "The thingy".into()), // 24 + 1 + 2 + 2 ], }; @@ -212,7 +214,7 @@ mod tests { assert_eq!(2, buf.len()); let mut stream: &[u8] = &*buf; - + let (pubrel, read_bytes) = PubRel::async_read(0, 2, &mut stream).await.unwrap(); assert_eq!(expected_pubrel, pubrel); @@ -224,13 +226,12 @@ mod tests { assert_eq!(3, buf.len()); let mut stream: &[u8] = &*buf; - + let (pubrel, read_bytes) = PubRel::async_read(0, 3, &mut stream).await.unwrap(); assert_eq!(read_bytes, 3); assert_eq!(expected_pubrel, pubrel); } - #[test] fn test_read_simple_pub_rel() { let stream = &[ @@ -323,7 +324,6 @@ mod tests { buf.extend(properties); - let mut stream = &*buf; // flags can be 0 because not used. // remaining_length must be at least 4 @@ -381,7 +381,7 @@ mod tests { assert_eq!(buf.to_vec(), result.to_vec()); assert_eq!(buf.len(), read_bytes); } - + #[test] fn no_reason_code_or_props() { let mut buf = BytesMut::new(); diff --git a/mqrstt/src/packets/reason_codes.rs b/mqrstt/src/packets/reason_codes.rs deleted file mode 100644 index 842383d..0000000 --- a/mqrstt/src/packets/reason_codes.rs +++ /dev/null @@ -1,254 +0,0 @@ - - - - -// #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -// pub enum ConnAckReasonCode { -// #[default] -// Success, - -// UnspecifiedError, -// MalformedPacket, -// ProtocolError, -// ImplementationSpecificError, -// UnsupportedProtocolVersion, -// ClientIdentifierNotValid, -// BadUsernameOrPassword, -// NotAuthorized, -// ServerUnavailable, -// ServerBusy, -// Banned, -// BadAuthenticationMethod, -// TopicNameInvalid, -// PacketTooLarge, -// QuotaExceeded, -// PayloadFormatInvalid, -// RetainNotSupported, -// QosNotSupported, -// UseAnotherServer, -// ServerMoved, -// ConnectionRateExceeded, -// } - -// impl MqttRead for ConnAckReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } -// let res = buf.get_u8(); - -// crate::packets::macros::reason_code_match!(@ ConnAckReasonCode, res, { -// Success, -// UnspecifiedError, -// MalformedPacket, -// ProtocolError, -// ImplementationSpecificError, -// UnsupportedProtocolVersion, -// ClientIdentifierNotValid, -// BadUsernameOrPassword, -// NotAuthorized, -// ServerUnavailable, -// ServerBusy, -// Banned, -// BadAuthenticationMethod, -// TopicNameInvalid, -// PacketTooLarge, -// QuotaExceeded, -// PayloadFormatInvalid, -// RetainNotSupported, -// QosNotSupported, -// UseAnotherServer, -// ServerMoved, -// ConnectionRateExceeded, -// } -> ()) -// // match buf.get_u8() { -// // 0x00 => Ok(ConnAckReasonCode::Success), -// // 0x80 => Ok(ConnAckReasonCode::UnspecifiedError), -// // 0x81 => Ok(ConnAckReasonCode::MalformedPacket), -// // 0x82 => Ok(ConnAckReasonCode::ProtocolError), -// // 0x83 => Ok(ConnAckReasonCode::ImplementationSpecificError), -// // 0x84 => Ok(ConnAckReasonCode::UnsupportedProtocolVersion), -// // 0x85 => Ok(ConnAckReasonCode::ClientIdentifierNotValid), -// // 0x86 => Ok(ConnAckReasonCode::BadUsernameOrPassword), -// // 0x87 => Ok(ConnAckReasonCode::NotAuthorized), -// // 0x88 => Ok(ConnAckReasonCode::ServerUnavailable), -// // 0x89 => Ok(ConnAckReasonCode::ServerBusy), -// // 0x8A => Ok(ConnAckReasonCode::Banned), -// // 0x8C => Ok(ConnAckReasonCode::BadAuthenticationMethod), -// // 0x90 => Ok(ConnAckReasonCode::TopicNameInvalid), -// // 0x95 => Ok(ConnAckReasonCode::PacketTooLarge), -// // 0x97 => Ok(ConnAckReasonCode::QuotaExceeded), -// // 0x99 => Ok(ConnAckReasonCode::PayloadFormatInvalid), -// // 0x9A => Ok(ConnAckReasonCode::RetainNotSupported), -// // 0x9B => Ok(ConnAckReasonCode::QosNotSupported), -// // 0x9C => Ok(ConnAckReasonCode::UseAnotherServer), -// // 0x9D => Ok(ConnAckReasonCode::ServerMoved), -// // 0x9F => Ok(ConnAckReasonCode::ConnectionRateExceeded), -// // t => Err(DeserializeError::UnknownProperty(t)), -// // } -// } -// } - -// impl MqttWrite for ConnAckReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// ConnAckReasonCode::Success => 0x00, -// ConnAckReasonCode::UnspecifiedError => 0x80, -// ConnAckReasonCode::MalformedPacket => 0x81, -// ConnAckReasonCode::ProtocolError => 0x82, -// ConnAckReasonCode::ImplementationSpecificError => 0x83, -// ConnAckReasonCode::UnsupportedProtocolVersion => 0x84, -// ConnAckReasonCode::ClientIdentifierNotValid => 0x85, -// ConnAckReasonCode::BadUsernameOrPassword => 0x86, -// ConnAckReasonCode::NotAuthorized => 0x87, -// ConnAckReasonCode::ServerUnavailable => 0x88, -// ConnAckReasonCode::ServerBusy => 0x89, -// ConnAckReasonCode::Banned => 0x8A, -// ConnAckReasonCode::BadAuthenticationMethod => 0x8C, -// ConnAckReasonCode::TopicNameInvalid => 0x90, -// ConnAckReasonCode::PacketTooLarge => 0x95, -// ConnAckReasonCode::QuotaExceeded => 0x97, -// ConnAckReasonCode::PayloadFormatInvalid => 0x99, -// ConnAckReasonCode::RetainNotSupported => 0x9A, -// ConnAckReasonCode::QosNotSupported => 0x9B, -// ConnAckReasonCode::UseAnotherServer => 0x9C, -// ConnAckReasonCode::ServerMoved => 0x9D, -// ConnAckReasonCode::ConnectionRateExceeded => 0x9F, -// }; - -// buf.put_u8(val); - -// Ok(()) -// } -// } - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum AuthReasonCode { -// Success, -// ContinueAuthentication, -// ReAuthenticate, -// } - -// impl MqttRead for AuthReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(AuthReasonCode::Success), -// 0x18 => Ok(AuthReasonCode::ContinueAuthentication), -// 0x19 => Ok(AuthReasonCode::ReAuthenticate), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for AuthReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// AuthReasonCode::Success => 0x00, -// AuthReasonCode::ContinueAuthentication => 0x18, -// AuthReasonCode::ReAuthenticate => 0x19, -// }; - -// buf.put_u8(val); - -// Ok(()) -// } -// } - - - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum PubCompReasonCode { -// Success, -// PacketIdentifierNotFound, -// } - -// impl MqttRead for PubCompReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(PubCompReasonCode::Success), -// 0x92 => Ok(PubCompReasonCode::PacketIdentifierNotFound), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } -// impl MqttWrite for PubCompReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// PubCompReasonCode::Success => 0x00, -// PubCompReasonCode::PacketIdentifierNotFound => 0x92, -// }; - -// buf.put_u8(val); -// Ok(()) -// } -// } - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum SubAckReasonCode { -// GrantedQoS0, -// GrantedQoS1, -// GrantedQoS2, -// UnspecifiedError, -// ImplementationSpecificError, -// NotAuthorized, -// TopicFilterInvalid, -// PacketIdentifierInUse, -// QuotaExceeded, -// SharedSubscriptionsNotSupported, -// SubscriptionIdentifiersNotSupported, -// WildcardSubscriptionsNotSupported, -// } - -// impl MqttRead for SubAckReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(SubAckReasonCode::GrantedQoS0), -// 0x01 => Ok(SubAckReasonCode::GrantedQoS1), -// 0x02 => Ok(SubAckReasonCode::GrantedQoS2), -// 0x80 => Ok(SubAckReasonCode::UnspecifiedError), -// 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), -// 0x87 => Ok(SubAckReasonCode::NotAuthorized), -// 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), -// 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), -// 0x97 => Ok(SubAckReasonCode::QuotaExceeded), -// 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), -// 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), -// 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for SubAckReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// SubAckReasonCode::GrantedQoS0 => 0x00, -// SubAckReasonCode::GrantedQoS1 => 0x01, -// SubAckReasonCode::GrantedQoS2 => 0x02, -// SubAckReasonCode::UnspecifiedError => 0x80, -// SubAckReasonCode::ImplementationSpecificError => 0x83, -// SubAckReasonCode::NotAuthorized => 0x87, -// SubAckReasonCode::TopicFilterInvalid => 0x8F, -// SubAckReasonCode::PacketIdentifierInUse => 0x91, -// SubAckReasonCode::QuotaExceeded => 0x97, -// SubAckReasonCode::SharedSubscriptionsNotSupported => 0x9E, -// SubAckReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, -// SubAckReasonCode::WildcardSubscriptionsNotSupported => 0xA2, -// }; - -// buf.put_u8(val); -// Ok(()) -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index b4e34a3..718909a 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -7,7 +7,7 @@ pub use reason_code::SubAckReasonCode; use bytes::BufMut; use super::{ - error::{SerializeError}, + error::SerializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, }; @@ -53,9 +53,9 @@ impl PacketAsyncRead for SubAck where S: tokio::io::AsyncReadExt + Unpin { fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; - let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + let packet_identifier = stream.read_u16().await?; let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; - total_read_bytes += id_read_bytes + proproperties_read_bytes; + total_read_bytes += 2 + proproperties_read_bytes; let mut reason_codes = vec![]; loop { let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; diff --git a/mqrstt/src/packets/subscribe.rs b/mqrstt/src/packets/subscribe/mod.rs similarity index 79% rename from mqrstt/src/packets/subscribe.rs rename to mqrstt/src/packets/subscribe/mod.rs index 3263205..de9b082 100644 --- a/mqrstt/src/packets/subscribe.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -1,8 +1,14 @@ +mod properties; +use std::ops::Sub; + +pub use properties::SubscribeProperties; +use tokio::io::AsyncReadExt; + use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketValidation, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, PacketType, PropertyType, QoS, VariableInteger, }; use bytes::{Buf, BufMut}; @@ -48,6 +54,36 @@ impl PacketRead for Subscribe { } } +impl PacketAsyncRead for Subscribe where S: tokio::io::AsyncReadExt + Unpin { + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; + total_read_bytes += topic_read_bytes + options_read_bytes; + topics.push((topic, options)); + + if remaining_length >= total_read_bytes { + break; + } + } + + Ok((Self { + packet_identifier, + properties, + topics, + }, total_read_bytes)) + } + } +} + + impl PacketWrite for Subscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -88,85 +124,6 @@ impl PacketValidation for Subscribe { } } -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct SubscribeProperties { - /// 3.8.2.1.2 Subscription Identifier - /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. - pub subscription_id: Option, - - /// 3.8.2.1.3 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(Box, Box)>, -} - -impl MqttRead for SubscribeProperties { - fn read(buf: &mut bytes::Bytes) -> Result { - let (len, _) = VariableInteger::read_variable_integer(buf)?; - - let mut properties = SubscribeProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); - } - - let mut properties_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut properties_data)? { - PropertyType::SubscriptionIdentifier => { - if properties.subscription_id.is_none() { - let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; - - properties.subscription_id = Some(subscription_id); - } else { - return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); - } - } - PropertyType::UserProperty => { - properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Subscribe)), - } - - if properties_data.is_empty() { - break; - } - } - Ok(properties) - } -} - -impl MqttWrite for SubscribeProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - self.wire_len().write_variable_integer(buf)?; - if let Some(sub_id) = self.subscription_id { - PropertyType::SubscriptionIdentifier.write(buf)?; - sub_id.write_variable_integer(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - Ok(()) - } -} - -impl WireLength for SubscribeProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(sub_id) = self.subscription_id { - len += 1 + sub_id.variable_integer_len(); - } - for (key, value) in &self.user_properties { - len += 1 + key.wire_len() + value.wire_len(); - } - len - } -} - #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct SubscriptionOptions { pub retain_handling: RetainHandling, @@ -210,6 +167,29 @@ impl MqttRead for SubscriptionOptions { } } +impl MqttAsyncRead for SubscriptionOptions where S: tokio::io::AsyncRead + Unpin { + fn async_read(stream: &mut S) -> impl std::future::Future> { + async move { + let byte = stream.read_u8().await?; + + let retain_handling_part = (byte & 0b00110000) >> 4; + let retain_as_publish_part = (byte & 0b00001000) >> 3; + let no_local_part = (byte & 0b00000100) >> 2; + let qos_part = byte & 0b00000011; + + let options = Self { + retain_handling: RetainHandling::from_u8(retain_handling_part)?, + retain_as_publish: retain_as_publish_part != 0, + no_local: no_local_part != 0, + qos: QoS::from_u8(qos_part)?, + }; + + Ok((options, 1)) + } + } +} + + impl MqttWrite for SubscriptionOptions { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); diff --git a/mqrstt/src/packets/subscribe/properties.rs b/mqrstt/src/packets/subscribe/properties.rs index e69de29..4ef6e36 100644 --- a/mqrstt/src/packets/subscribe/properties.rs +++ b/mqrstt/src/packets/subscribe/properties.rs @@ -0,0 +1,94 @@ + + +use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; + +use crate::packets::{ + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, +}; + +crate::packets::macros::define_properties!( + SubscribeProperties, + SubscriptionIdentifier, + UserProperty +); + +// #[derive(Debug, Default, PartialEq, Eq, Clone)] +// pub struct SubscribeProperties { +// /// 3.8.2.1.2 Subscription Identifier +// /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. +// pub subscription_id: Option, + +// /// 3.8.2.1.3 User Property +// /// 38 (0x26) Byte, Identifier of the User Property. +// pub user_properties: Vec<(Box, Box)>, +// } + +impl MqttRead for SubscribeProperties { + fn read(buf: &mut bytes::Bytes) -> Result { + let (len, _) = VariableInteger::read_variable_integer(buf)?; + + let mut properties = SubscribeProperties::default(); + + if len == 0 { + return Ok(properties); + } else if buf.len() < len { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); + } + + let mut properties_data = buf.split_to(len); + + loop { + match PropertyType::read(&mut properties_data)? { + PropertyType::SubscriptionIdentifier => { + if properties.subscription_identifier.is_none() { + let (subscription_id, _) = VariableInteger::read_variable_integer(&mut properties_data)?; + + properties.subscription_identifier = Some(subscription_id); + } else { + return Err(DeserializeError::DuplicateProperty(PropertyType::SubscriptionIdentifier)); + } + } + PropertyType::UserProperty => { + properties.user_properties.push((Box::::read(&mut properties_data)?, Box::::read(&mut properties_data)?)); + } + e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Subscribe)), + } + + if properties_data.is_empty() { + break; + } + } + Ok(properties) + } +} + +impl MqttWrite for SubscribeProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { + self.wire_len().write_variable_integer(buf)?; + if let Some(sub_id) = self.subscription_identifier { + PropertyType::SubscriptionIdentifier.write(buf)?; + sub_id.write_variable_integer(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + Ok(()) + } +} + +// impl WireLength for SubscribeProperties { +// fn wire_len(&self) -> usize { +// let mut len = 0; +// if let Some(sub_id) = self.subscription_identifier { +// len += 1 + sub_id.variable_integer_len(); +// } +// for (key, value) in &self.user_properties { +// len += 1 + key.wire_len() + value.wire_len(); +// } +// len +// } +// } \ No newline at end of file diff --git a/mqrstt/src/packets/subscribe/reason_code.rs b/mqrstt/src/packets/subscribe/reason_code.rs deleted file mode 100644 index e69de29..0000000 diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index 6e3bb25..dd4a6bf 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -4,11 +4,15 @@ pub use properties::UnsubAckProperties; mod reason_code; pub use reason_code::UnsubAckReasonCode; +use crate::packets::mqtt_trait::MqttAsyncRead; use bytes::BufMut; -use super::error::{SerializeError}; +use tokio::io::AsyncReadExt; + +use super::error::SerializeError; use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; +use super::PacketAsyncRead; #[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct UnsubAck { @@ -40,6 +44,43 @@ impl PacketRead for UnsubAck { } } +impl PacketAsyncRead for UnsubAck +where + S: tokio::io::AsyncRead + Unpin, +{ + fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { + async move { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + + let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + + reason_codes.push(reason_code); + + if total_read_bytes >= remaining_length { + break; + } + } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) + } + } +} + impl PacketWrite for UnsubAck { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.packet_identifier); diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index e9c8d3d..23eb72a 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -57,9 +57,9 @@ impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncReadExt + Un fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; - let (packet_identifier, id_read_bytes) = u16::async_read(stream).await?; + let packet_identifier = stream.read_u16().await?; let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; - total_read_bytes += id_read_bytes + properties_read_bytes; + total_read_bytes += 2 + properties_read_bytes; let mut topics = vec![]; loop { diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 08c39b9..2af3b0f 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -1,12 +1,125 @@ -use bytes::Bytes; - use rstest::*; -use crate::packets::{ - DisconnectReasonCode, ConnAck, Disconnect, DisconnectProperties, Packet, PubAck, PubAckProperties, PubAckReasonCode, Publish, PublishProperties, QoS, Subscribe, Subscription, Unsubscribe -}; +use crate::packets::*; + +pub fn connack_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[ + 0x20, 0x13, 0x01, 0x00, 0x10, 0x27, 0x00, 0x10, 0x00, 0x00, 0x25, 0x01, 0x2a, 0x01, 0x29, 0x01, 0x22, 0xff, 0xff, 0x28, 0x01, + ]; + + let expected = ConnAck { + connack_flags: ConnAckFlags { session_present: true }, + reason_code: ConnAckReasonCode::Success, + connack_properties: ConnAckProperties { + session_expiry_interval: None, + receive_maximum: None, + maximum_qos: None, + retain_available: Some(true), + maximum_packet_size: Some(1048576), + assigned_client_id: None, + topic_alias_maximum: Some(65535), + reason_string: None, + user_properties: vec![], + wildcards_available: Some(true), + subscription_ids_available: Some(true), + shared_subscription_available: Some(true), + server_keep_alive: None, + response_info: None, + server_reference: None, + authentication_method: None, + authentication_data: None, + }, + }; + + (packet, Packet::ConnAck(expected)) +} + +pub fn disconnect_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xe0, 0x02, 0x8e, 0x00]; + + let expected = Disconnect { + reason_code: DisconnectReasonCode::SessionTakenOver, + properties: DisconnectProperties { + session_expiry_interval: None, + reason_string: None, + user_properties: vec![], + server_reference: None, + }, + }; + + (packet, Packet::Disconnect(expected)) +} + +pub fn ping_req_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xc0, 0x00]; + + (packet, Packet::PingReq) +} + +pub fn ping_resp_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0xd0, 0x00]; + + (packet, Packet::PingResp) +} +pub fn publish_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[ + 0x35, 0x24, 0x00, 0x14, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x31, 0x32, 0x33, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x2f, 0x62, 0x6c, 0x61, 0x62, 0x6c, 0x61, 0x35, 0xd3, 0x0b, 0x01, 0x01, 0x09, 0x00, 0x04, + 0x31, 0x32, 0x31, 0x32, 0x0b, 0x01, + ]; + + let expected = Publish { + dup: false, + qos: QoS::ExactlyOnce, + retain: true, + topic: "test/123/test/blabla".into(), + packet_identifier: Some(13779), + publish_properties: PublishProperties { + payload_format_indicator: Some(1), + message_expiry_interval: None, + topic_alias: None, + response_topic: None, + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], + user_properties: vec![], + content_type: None, + }, + payload: b"".to_vec(), + }; + + (packet, Packet::Publish(expected)) +} + +pub fn pubrel_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0x62, 0x02, 0x35, 0xd3]; + + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + (packet, Packet::PubRel(expected)) +} + +pub fn pubrel_smallest_case() -> (&'static [u8], Packet) { + let packet: &'static [u8] = &[0x62, 0x02, 0x35, 0xd3]; + + let expected = PubRel { + packet_identifier: 13779, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: None, + user_properties: vec![], + }, + }; + + (packet, Packet::PubRel(expected)) +} -fn publish_packet_1() -> Packet { +pub fn publish_packet_1() -> Packet { Packet::Publish(Publish { dup: false, qos: QoS::ExactlyOnce, @@ -23,10 +136,10 @@ fn publish_packet_1() -> Packet { user_properties: vec![], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_2() -> Packet { +pub fn publish_packet_2() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::ExactlyOnce, @@ -43,10 +156,10 @@ fn publish_packet_2() -> Packet { user_properties: vec![], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_3() -> Packet { +pub fn publish_packet_3() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::AtLeastOnce, @@ -63,10 +176,10 @@ fn publish_packet_3() -> Packet { user_properties: vec![("blabla".into(), "another blabla".into())], content_type: None, }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), }) } -fn publish_packet_4() -> Packet { +pub fn publish_packet_4() -> Packet { Packet::Publish(Publish { dup: true, qos: QoS::AtLeastOnce, @@ -83,7 +196,7 @@ fn publish_packet_4() -> Packet { user_properties: vec![], content_type: Some("Garbage".into()), }, - payload: Bytes::from_static(b""), + payload: b"".to_vec(), // payload: Bytes::from_iter(b"abcdefg".repeat(500)), }) } @@ -116,7 +229,7 @@ pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifie user_properties: vec![], content_type: Some("Garbage".into()), }, - payload: Bytes::from_iter(b"testabcbba==asdasdasdasdasd".repeat(500)), + payload: b"testabcbba==asdasdasdasdasd".repeat(500).to_vec(), }) } From 7f939c337a22b00bae6ad1ba07f1dff092c99c13 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Wed, 20 Nov 2024 23:50:31 +0100 Subject: [PATCH 10/45] use codecov secret --- .github/workflows/rust.yml | 3 ++- deny.toml | 10 ++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 487fc0f..4ce4a31 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -52,7 +52,7 @@ jobs: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v3 - - uses: EmbarkStudios/cargo-deny-action@v1 + - uses: EmbarkStudios/cargo-deny-action@v2 coverage: name: Coverage @@ -84,4 +84,5 @@ jobs: - name: Upload coverage report uses: codecov/codecov-action@v3 with: + token: ${{ secrets.CODECOV_TOKEN }} files: ./lcov.txt diff --git a/deny.toml b/deny.toml index 7e10246..ee3595b 100644 --- a/deny.toml +++ b/deny.toml @@ -1,14 +1,8 @@ -[advisories] -vulnerability = "deny" -unsound = "deny" -unmaintained = "deny" + [licenses] -unlicensed = "deny" -allow-osi-fsf-free = "neither" -copyleft = "deny" confidence-threshold = 0.95 -allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC"] +allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC", "Unicode-3.0"] exceptions = [ { allow = ["Unicode-DFS-2016"], name = "unicode-ident" }, From 8b5486f68566e9255c745221b8c389cdb4007192 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 21 Nov 2024 00:00:35 +0100 Subject: [PATCH 11/45] remove commented code --- mqrstt/src/packets/auth/properties.rs | 51 +------- mqrstt/src/packets/connack/properties.rs | 78 +----------- .../src/packets/connect/connect_properties.rs | 44 +------ .../packets/connect/last_will_properties.rs | 48 +------ mqrstt/src/packets/disconnect/properties.rs | 34 +---- mqrstt/src/packets/disconnect/reason_code.rs | 120 +----------------- mqrstt/src/packets/macros/mod.rs | 11 +- mqrstt/src/packets/mod.rs | 18 +-- mqrstt/src/packets/mqtt_trait/mod.rs | 78 ++---------- mqrstt/src/packets/primitive/fixed_header.rs | 22 ++-- mqrstt/src/packets/puback/mod.rs | 71 ++++------- mqrstt/src/packets/puback/reason_code.rs | 55 -------- mqrstt/src/packets/pubcomp/properties.rs | 29 ----- mqrstt/src/packets/publish/mod.rs | 48 +++---- mqrstt/src/packets/publish/properties.rs | 73 +---------- mqrstt/src/packets/pubrec/properties.rs | 29 +---- mqrstt/src/packets/pubrec/reason_code.rs | 59 +-------- mqrstt/src/packets/pubrel/properties.rs | 65 +--------- mqrstt/src/packets/pubrel/reason_code.rs | 50 +------- mqrstt/src/packets/suback/properties.rs | 33 +---- mqrstt/src/packets/suback/reason_code.rs | 46 ------- mqrstt/src/packets/subscribe/mod.rs | 30 +++-- mqrstt/src/packets/subscribe/properties.rs | 34 +---- mqrstt/src/packets/unsuback/properties.rs | 25 +--- mqrstt/src/packets/unsuback/reason_code.rs | 48 ------- mqrstt/src/packets/unsubscribe/properties.rs | 16 --- 26 files changed, 130 insertions(+), 1085 deletions(-) diff --git a/mqrstt/src/packets/auth/properties.rs b/mqrstt/src/packets/auth/properties.rs index 259b11d..1ed1a60 100644 --- a/mqrstt/src/packets/auth/properties.rs +++ b/mqrstt/src/packets/auth/properties.rs @@ -2,35 +2,11 @@ use bytes::Bytes; use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite,WireLength}, PacketType, PropertyType, VariableInteger, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!( - AuthProperties, - AuthenticationMethod, - AuthenticationData, - ReasonString, - UserProperty -); - -// #[derive(Debug, Default, PartialEq, Eq, Clone)] -// pub struct AuthProperties { -// /// 3.15.2.2.2 Authentication Method -// /// 21 (0x15) Byte, Identifier of the Authentication Method. -// pub authentication_method: Option>, - -// /// 3.15.2.2.3 Authentication Data -// /// 22 (0x16) Byte, Identifier of the Authentication Data -// pub authentication_data: Vec, - -// /// 3.15.2.2.4 Reason String -// /// 31 (0x1F) Byte, Identifier of the Reason String -// pub reason_string: Option>, - -// /// 3.15.2.2.5 User Property -// /// 38 (0x26) Byte, Identifier of the User Property. -// pub user_properties: Vec<(Box, Box)>, -// } +crate::packets::macros::define_properties!(AuthProperties, AuthenticationMethod, AuthenticationData, ReasonString, UserProperty); impl MqttRead for AuthProperties { fn read(buf: &mut Bytes) -> Result { @@ -106,24 +82,3 @@ impl MqttWrite for AuthProperties { Ok(()) } } - -// impl WireLength for AuthProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(authentication_method) = &self.authentication_method { -// len += 1 + authentication_method.wire_len(); -// } -// if let Some(authentication_data) = self.authentication_data { -// if !authentication_data.is_empty() && self.authentication_method.is_some() { -// len += 1 + authentication_data.wire_len(); -// } -// } -// if let Some(reason_string) = &self.reason_string { -// len += 1 + reason_string.wire_len(); -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/connack/properties.rs b/mqrstt/src/packets/connack/properties.rs index 91b1cf4..979b978 100644 --- a/mqrstt/src/packets/connack/properties.rs +++ b/mqrstt/src/packets/connack/properties.rs @@ -5,7 +5,8 @@ use crate::packets::{ }; use bytes::BufMut; -crate::packets::macros::define_properties!(ConnAckProperties, +crate::packets::macros::define_properties!( + ConnAckProperties, SessionExpiryInterval, ReceiveMaximum, MaximumQos, @@ -25,79 +26,6 @@ crate::packets::macros::define_properties!(ConnAckProperties, AuthenticationData ); -// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -// pub struct ConnAckProperties { -// /// 3.2.2.3.2 Session Expiry Interval -// /// 17 (0x11) Byte Identifier of the Session Expiry Interval -// pub session_expiry_interval: Option, - -// /// 3.2.2.3.3 Receive Maximum -// /// 33 (0x21) Byte, Identifier of the Receive Maximum -// pub receive_maximum: Option, - -// /// 3.2.2.3.4 Maximum QoS -// /// 36 (0x24) Byte, Identifier of the Maximum QoS. -// pub maximum_qos: Option, - -// /// 3.2.2.3.5 Retain Available -// /// 37 (0x25) Byte, Identifier of Retain Available. -// pub retain_available: Option, - -// /// 3.2.2.3.6 Maximum Packet Size -// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. -// pub maximum_packet_size: Option, - -// /// 3.2.2.3.7 Assigned Client Identifier -// /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. -// pub assigned_client_id: Option>, - -// /// 3.2.2.3.8 Topic Alias Maximum -// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. -// pub topic_alias_maximum: Option, - -// /// 3.2.2.3.9 Reason String -// /// 31 (0x1F) Byte Identifier of the Reason String. -// pub reason_string: Option>, - -// /// 3.2.2.3.10 User Property -// /// 38 (0x26) Byte, Identifier of User Property. -// pub user_properties: Vec<(Box, Box)>, - -// /// 3.2.2.3.11 Wildcard Subscription Available -// /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. -// pub wildcards_available: Option, - -// /// 3.2.2.3.12 Subscription Identifiers Available -// /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. -// pub subscription_ids_available: Option, - -// /// 3.2.2.3.13 Shared Subscription Available -// /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. -// pub shared_subscription_available: Option, - -// /// 3.2.2.3.14 Server Keep Alive -// /// 19 (0x13) Byte, Identifier of the Server Keep Alive -// pub server_keep_alive: Option, - -// /// 3.2.2.3.15 Response Information -// /// 26 (0x1A) Byte, Identifier of the Response Information. -// pub response_info: Option>, - -// /// 3.2.2.3.16 Server Reference -// /// 28 (0x1C) Byte, Identifier of the Server Reference -// pub server_reference: Option>, - -// /// 3.2.2.3.17 Authentication Method -// /// 21 (0x15) Byte, Identifier of the Authentication Method -// pub authentication_method: Option>, - -// /// 3.2.2.3.18 Authentication Data -// /// 22 (0x16) Byte, Identifier of the Authentication Data -// // There is a small inconsistency here with authentication_data in the connect packet. -// // This is Option while that type uses just Bytes. -// pub authentication_data: Option>, -// } - impl MqttRead for ConnAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; @@ -323,4 +251,4 @@ impl MqttWrite for ConnAckProperties { Ok(()) } -} \ No newline at end of file +} diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs index ca25dc5..9a22591 100644 --- a/mqrstt/src/packets/connect/connect_properties.rs +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -1,6 +1,5 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; - use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, @@ -8,7 +7,6 @@ use crate::packets::{ PacketType, PropertyType, WireLength, }; - // /// Connect Properties // /// // /// The wire representation starts with the length of all properties after which @@ -17,7 +15,8 @@ use crate::packets::{ // /// 3.1.2.11.1 Property Length // /// The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. // /// Followed by all possible connect properties: -crate::packets::macros::define_properties!(ConnectProperties, +crate::packets::macros::define_properties!( + ConnectProperties, SessionExpiryInterval, ReceiveMaximum, MaximumPacketSize, @@ -29,45 +28,6 @@ crate::packets::macros::define_properties!(ConnectProperties, AuthenticationData ); -// #[derive(Debug, Default, Clone, PartialEq, Eq)] -// pub struct ConnectProperties { -// /// 3.1.2.11.2 Session Expiry Interval -// /// 17 (0x11) Byte Identifier of the Session Expiry Interval -// pub session_expiry_interval: Option, - -// /// 3.1.2.11.3 Receive Maximum -// /// 33 (0x21) Byte, Identifier of the Receive Maximum -// pub receive_maximum: Option, - -// /// 3.1.2.11.4 Maximum Packet Size -// /// 39 (0x27) Byte, Identifier of the Maximum Packet Size -// pub maximum_packet_size: Option, - -// /// 3.1.2.11.5 Topic Alias Maximum -// /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum -// pub topic_alias_maximum: Option, - -// /// 3.1.2.11.6 Request Response Information -// /// 25 (0x19) Byte, Identifier of the Request Response Information -// pub request_response_information: Option, - -// /// 3.1.2.11.7 Request Problem Information -// /// 23 (0x17) Byte, Identifier of the Request Problem Information -// pub request_problem_information: Option, - -// /// 3.1.2.11.8 User Property -// /// 38 (0x26) Byte, Identifier of the User Property -// pub user_properties: Vec<(Box, Box)>, - -// /// 3.1.2.11.9 Authentication Method -// /// 21 (0x15) Byte, Identifier of the Authentication Method -// pub authentication_method: Option>, - -// /// 3.1.2.11.10 Authentication Data -// /// 22 (0x16) Byte, Identifier of the Authentication Data -// pub authentication_data: Bytes, -// } - impl MqttWrite for ConnectProperties { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { self.wire_len().write_variable_integer(buf)?; diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs index 8051453..9e7db09 100644 --- a/mqrstt/src/packets/connect/last_will_properties.rs +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -1,6 +1,5 @@ use bytes::{BufMut, Bytes, BytesMut}; - use crate::packets::VariableInteger; use crate::packets::{ error::{DeserializeError, SerializeError}, @@ -8,7 +7,8 @@ use crate::packets::{ PacketType, PropertyType, WireLength, }; -crate::packets::macros::define_properties!(LastWillProperties, +crate::packets::macros::define_properties!( + LastWillProperties, WillDelayInterval, PayloadFormatIndicator, MessageExpiryInterval, @@ -18,25 +18,6 @@ crate::packets::macros::define_properties!(LastWillProperties, UserProperty ); - -// #[derive(Debug, Default, Clone, PartialEq, Eq)] -// pub struct LastWillProperties { -// /// 3.1.3.2.2 Will Delay Interval -// will_delay_interval: Option, -// /// 3.1.3.2.3 Payload Format Indicator -// payload_format_indicator: Option, -// /// 3.1.3.2.4 Message Expiry Interval -// message_expiry_interval: Option, -// /// 3.1.3.2.5 Content Type -// content_type: Option>, -// /// 3.1.3.2.6 Response Topic -// response_topic: Option>, -// /// 3.1.3.2.7 Correlation Data -// correlation_data: Option, -// /// 3.1.3.2.8 User Property -// user_properties: Vec<(Box, Box)>, -// } - impl MqttRead for LastWillProperties { fn read(buf: &mut Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf)?; @@ -139,28 +120,3 @@ impl MqttWrite for LastWillProperties { Ok(()) } } - -// impl WireLength for LastWillProperties { -// fn wire_len(&self) -> usize { -// let mut len: usize = 0; - -// if self.will_delay_interval.is_some() { -// len += 5; -// } -// if self.payload_format_indicator.is_some() { -// len += 2; -// } -// if self.message_expiry_interval.is_some() { -// len += 5; -// } -// // +1 for the property type -// len += self.content_type.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); -// len += self.response_topic.as_ref().map_or_else(|| 0, |s| s.wire_len() + 1); -// len += self.correlation_data.as_ref().map_or_else(|| 0, |b| b.wire_len() + 1); -// for (key, value) in &self.user_properties { -// len += key.wire_len() + value.wire_len() + 1; -// } - -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/disconnect/properties.rs b/mqrstt/src/packets/disconnect/properties.rs index b750082..1c5d2d0 100644 --- a/mqrstt/src/packets/disconnect/properties.rs +++ b/mqrstt/src/packets/disconnect/properties.rs @@ -2,24 +2,11 @@ use bytes::BufMut; use crate::packets::{ error::DeserializeError, - mqtt_trait::{ MqttRead, MqttWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(DisconnectProperties, - SessionExpiryInterval, - ReasonString, - UserProperty, - ServerReference -); - -// #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -// pub struct DisconnectProperties { -// pub session_expiry_interval: Option, -// pub reason_string: Option>, -// pub user_properties: Vec<(Box, Box)>, -// pub server_reference: Option>, -// } +crate::packets::macros::define_properties!(DisconnectProperties, SessionExpiryInterval, ReasonString, UserProperty, ServerReference); impl MqttRead for DisconnectProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -91,20 +78,3 @@ impl MqttWrite for DisconnectProperties { Ok(()) } } - -// impl WireLength for DisconnectProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if self.session_expiry_interval.is_some() { -// len += 4 + 1; -// } -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// len += self.user_properties.iter().fold(0, |acc, (k, v)| acc + k.wire_len() + v.wire_len() + 1); -// if let Some(server_refrence) = &self.server_reference { -// len += server_refrence.wire_len() + 1; -// } -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/disconnect/reason_code.rs b/mqrstt/src/packets/disconnect/reason_code.rs index 6c993de..e8b1c00 100644 --- a/mqrstt/src/packets/disconnect/reason_code.rs +++ b/mqrstt/src/packets/disconnect/reason_code.rs @@ -1,4 +1,5 @@ -crate::packets::macros::reason_code!(DisconnectReasonCode, +crate::packets::macros::reason_code!( + DisconnectReasonCode, NormalDisconnection, DisconnectWithWillMessage, UnspecifiedError, @@ -29,120 +30,3 @@ crate::packets::macros::reason_code!(DisconnectReasonCode, SubscriptionIdentifiersNotSupported, WildcardSubscriptionsNotSupported ); - -// #[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum DisconnectReasonCode { -// #[default] -// NormalDisconnection, -// DisconnectWithWillMessage, -// UnspecifiedError, -// MalformedPacket, -// ProtocolError, -// ImplementationSpecificError, -// NotAuthorized, -// ServerBusy, -// ServerShuttingDown, -// KeepAliveTimeout, -// SessionTakenOver, -// TopicFilterInvalid, -// TopicNameInvalid, -// ReceiveMaximumExceeded, -// TopicAliasInvalid, -// PacketTooLarge, -// MessageRateTooHigh, -// QuotaExceeded, -// AdministrativeAction, -// PayloadFormatInvalid, -// RetainNotSupported, -// QosNotSupported, -// UseAnotherServer, -// ServerMoved, -// SharedSubscriptionsNotSupported, -// ConnectionRateExceeded, -// MaximumConnectTime, -// SubscriptionIdentifiersNotSupported, -// WildcardSubscriptionsNotSupported, -// } - -// impl MqttRead for DisconnectReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(DisconnectReasonCode::NormalDisconnection), -// 0x04 => Ok(DisconnectReasonCode::DisconnectWithWillMessage), -// 0x80 => Ok(DisconnectReasonCode::UnspecifiedError), -// 0x81 => Ok(DisconnectReasonCode::MalformedPacket), -// 0x82 => Ok(DisconnectReasonCode::ProtocolError), -// 0x83 => Ok(DisconnectReasonCode::ImplementationSpecificError), -// 0x87 => Ok(DisconnectReasonCode::NotAuthorized), -// 0x89 => Ok(DisconnectReasonCode::ServerBusy), -// 0x8B => Ok(DisconnectReasonCode::ServerShuttingDown), -// 0x8D => Ok(DisconnectReasonCode::KeepAliveTimeout), -// 0x8E => Ok(DisconnectReasonCode::SessionTakenOver), -// 0x8F => Ok(DisconnectReasonCode::TopicFilterInvalid), -// 0x90 => Ok(DisconnectReasonCode::TopicNameInvalid), -// 0x93 => Ok(DisconnectReasonCode::ReceiveMaximumExceeded), -// 0x94 => Ok(DisconnectReasonCode::TopicAliasInvalid), -// 0x95 => Ok(DisconnectReasonCode::PacketTooLarge), -// 0x96 => Ok(DisconnectReasonCode::MessageRateTooHigh), -// 0x97 => Ok(DisconnectReasonCode::QuotaExceeded), -// 0x98 => Ok(DisconnectReasonCode::AdministrativeAction), -// 0x99 => Ok(DisconnectReasonCode::PayloadFormatInvalid), -// 0x9A => Ok(DisconnectReasonCode::RetainNotSupported), -// 0x9B => Ok(DisconnectReasonCode::QosNotSupported), -// 0x9C => Ok(DisconnectReasonCode::UseAnotherServer), -// 0x9D => Ok(DisconnectReasonCode::ServerMoved), -// 0x9E => Ok(DisconnectReasonCode::SharedSubscriptionsNotSupported), -// 0x9F => Ok(DisconnectReasonCode::ConnectionRateExceeded), -// 0xA0 => Ok(DisconnectReasonCode::MaximumConnectTime), -// 0xA1 => Ok(DisconnectReasonCode::SubscriptionIdentifiersNotSupported), -// 0xA2 => Ok(DisconnectReasonCode::WildcardSubscriptionsNotSupported), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for DisconnectReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// DisconnectReasonCode::NormalDisconnection => 0x00, -// DisconnectReasonCode::DisconnectWithWillMessage => 0x04, -// DisconnectReasonCode::UnspecifiedError => 0x80, -// DisconnectReasonCode::MalformedPacket => 0x81, -// DisconnectReasonCode::ProtocolError => 0x82, -// DisconnectReasonCode::ImplementationSpecificError => 0x83, -// DisconnectReasonCode::NotAuthorized => 0x87, -// DisconnectReasonCode::ServerBusy => 0x89, -// DisconnectReasonCode::ServerShuttingDown => 0x8B, - -// DisconnectReasonCode::KeepAliveTimeout => 0x8D, -// DisconnectReasonCode::SessionTakenOver => 0x8E, -// DisconnectReasonCode::TopicFilterInvalid => 0x8F, - -// DisconnectReasonCode::TopicNameInvalid => 0x90, -// DisconnectReasonCode::ReceiveMaximumExceeded => 0x93, -// DisconnectReasonCode::TopicAliasInvalid => 0x94, -// DisconnectReasonCode::PacketTooLarge => 0x95, -// DisconnectReasonCode::MessageRateTooHigh => 0x96, -// DisconnectReasonCode::QuotaExceeded => 0x97, -// DisconnectReasonCode::AdministrativeAction => 0x98, -// DisconnectReasonCode::PayloadFormatInvalid => 0x99, -// DisconnectReasonCode::RetainNotSupported => 0x9A, -// DisconnectReasonCode::QosNotSupported => 0x9B, -// DisconnectReasonCode::UseAnotherServer => 0x9C, -// DisconnectReasonCode::ServerMoved => 0x9D, -// DisconnectReasonCode::SharedSubscriptionsNotSupported => 0x9E, -// DisconnectReasonCode::ConnectionRateExceeded => 0x9F, -// DisconnectReasonCode::MaximumConnectTime => 0xA0, -// DisconnectReasonCode::SubscriptionIdentifiersNotSupported => 0xA1, -// DisconnectReasonCode::WildcardSubscriptionsNotSupported => 0xA2, -// }; - -// buf.put_u8(val); - -// Ok(()) -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/macros/mod.rs b/mqrstt/src/packets/macros/mod.rs index 79836ca..22e829f 100644 --- a/mqrstt/src/packets/macros/mod.rs +++ b/mqrstt/src/packets/macros/mod.rs @@ -1,14 +1,5 @@ mod properties_macros; mod reason_code_macros; - -pub(crate) use reason_code_macros::*; pub(crate) use properties_macros::*; - -// macro_rules! assert_length { -// ($len:ident, $read:expr) => { -// if len != (read) { -// return Err(DeserializeError::InvalidLength(std::any::type_name::(), len, read)); -// } -// }; -// } \ No newline at end of file +pub(crate) use reason_code_macros::*; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 349e1b0..31d3707 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -36,10 +36,10 @@ pub use subscribe::*; pub use unsuback::*; pub use unsubscribe::*; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use std::fmt::Display; -use self::error::{DeserializeError, ReadBytes, SerializeError}; +use self::error::{DeserializeError, SerializeError}; use self::mqtt_trait::{PacketRead, PacketWrite, WireLength}; // ==================== Packets ==================== @@ -226,7 +226,10 @@ impl Packet { } #[cfg(test)] - pub(crate) fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + pub(crate) fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + use bytes::Buf; + use error::ReadBytes; + let (header, header_length) = FixedHeader::read_fixed_header(buffer.iter())?; if header.remaining_length + header_length > buffer.len() { return Err(ReadBytes::InsufficientBytes(header.remaining_length + header_length - buffer.len())); @@ -322,15 +325,6 @@ impl std::fmt::Display for PacketType { mod tests { use bytes::BytesMut; - use crate::packets::connack::{ConnAck, ConnAckFlags, ConnAckProperties}; - use crate::packets::disconnect::{Disconnect, DisconnectProperties}; - use crate::packets::QoS; - - use crate::packets::connack::ConnAckReasonCode; - use crate::packets::disconnect::DisconnectReasonCode; - use crate::packets::publish::{Publish, PublishProperties}; - use crate::packets::pubrel::PubRelReasonCode; - use crate::packets::pubrel::{PubRel, PubRelProperties}; use crate::packets::Packet; use crate::tests::test_packets::{disconnect_case, ping_req_case, ping_resp_case, publish_case, pubrel_case, pubrel_smallest_case}; diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index 0bd95be..daada69 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -1,78 +1,28 @@ mod primitive_impl; -use std::{future::Future}; +use std::future::Future; use bytes::{Bytes, BytesMut}; -// mod sealed { -// /// Sealed trait to prevent downstream users from implementing the -// /// [`crate::packets::mqtt_trait::PacketRead`], [`crate::packets::mqtt_trait::PacketWrite`], -// /// [`crate::packets::mqtt_trait::PacketAsyncRead`] [`crate::packets::mqtt_trait::PacketAsyncWrite`], -// /// [`crate::packets::mqtt_trait::MqttRead`], [`crate::packets::mqtt_trait::MqttWrite`] -// /// and [`crate::packets::mqtt_trait::WireLength`] traits. -// pub trait Sealed {} -// impl Sealed for crate::packets::ConnAck {} - -// } - -// pub(crate) trait PacketRead: Sized + sealed::Sealed { -// fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; -// } - -// pub(crate) trait PacketAsyncRead: Sized + sealed::Sealed where S: tokio::io::AsyncReadExt + Unpin { -// fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; -// } - -// pub(crate) trait PacketWrite: Sized + sealed::Sealed { -// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; -// } - -// pub(crate) trait WireLength: sealed::Sealed { -// fn wire_len(&self) -> usize; -// } - -// pub(crate) trait MqttRead: Sized + sealed::Sealed { -// fn read(buf: &mut Bytes) -> Result; -// } -// pub trait MqttAsyncRead: Sized + sealed::Sealed -// { -// /// Reads `Self` from the provided stream. -// /// Returns the deserialized instance and the number of bytes read from the stream. -// fn async_read(stream: &mut S) -> impl Future>; -// } - - -// pub trait MqttWrite: Sized + sealed::Sealed { -// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; -// } - -// impl<'a, T> MqttWrite for &'a T -// where -// T: MqttWrite, -// &'a T: sealed::Sealed, -// { -// fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError> { -// ::write(self, buf) -// } -// } - -// pub trait PacketValidation: Sized + sealed::Sealed { -// fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; -// } - -pub(crate) trait PacketRead: Sized { +pub(crate) trait PacketRead: Sized { fn read(flags: u8, remaining_length: usize, buf: Bytes) -> Result; } -pub(crate) trait PacketAsyncRead: Sized where S: tokio::io::AsyncReadExt + Unpin { +pub(crate) trait PacketAsyncRead: Sized +where + S: tokio::io::AsyncReadExt + Unpin, +{ fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; } -pub(crate) trait PacketAsyncWrite: Sized where S: tokio::io::AsyncWriteExt + Unpin { +pub(crate) trait PacketAsyncWrite: Sized +where + S: tokio::io::AsyncWriteExt + Unpin, +{ fn async_write(&self, stream: &mut S) -> impl Future>; } -pub(crate) trait PacketWrite: Sized { +pub(crate) trait PacketWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; } @@ -80,17 +30,15 @@ pub(crate) trait WireLength { fn wire_len(&self) -> usize; } -pub(crate) trait MqttRead: Sized { +pub(crate) trait MqttRead: Sized { fn read(buf: &mut Bytes) -> Result; } -pub(crate) trait MqttAsyncRead: Sized -{ +pub(crate) trait MqttAsyncRead: Sized { /// Reads `Self` from the provided stream. /// Returns the deserialized instance and the number of bytes read from the stream. fn async_read(stream: &mut S) -> impl Future>; } - pub trait MqttWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; } diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs index cb49a93..9f28a71 100644 --- a/mqrstt/src/packets/primitive/fixed_header.rs +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -9,17 +9,17 @@ use crate::packets::{ use super::read_fixed_header_rem_len; -// 2.1.1 Fixed Header -// ``` -// 7 3 0 -// +--------------------------+--------------------------+ -// byte 1 | MQTT Control Packet Type | Flags for Packet type | -// +--------------------------+--------------------------+ -// | Remaining Length | -// +-----------------------------------------------------+ -// -// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 -// ``` +/// 2.1.1 Fixed Header +/// ``` +/// 7 3 0 +/// +--------------------------+--------------------------+ +/// byte 1 | MQTT Control Packet Type | Flags for Packet type | +/// +--------------------------+--------------------------+ +/// | Remaining Length | +/// +-----------------------------------------------------+ +/// +/// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 +/// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub(crate) struct FixedHeader { pub packet_type: PacketType, diff --git a/mqrstt/src/packets/puback/mod.rs b/mqrstt/src/packets/puback/mod.rs index 4420837..a6e35cc 100644 --- a/mqrstt/src/packets/puback/mod.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -16,28 +16,40 @@ pub struct PubAck { pub properties: PubAckProperties, } -impl PacketAsyncRead for PubAck where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for PubAck +where + S: tokio::io::AsyncReadExt + Unpin, +{ async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { let packet_identifier = stream.read_u16().await?; if remaining_length == 2 { - Ok((Self { - packet_identifier, - reason_code: PubAckReasonCode::Success, - properties: PubAckProperties::default(), - }, 2)) + Ok(( + Self { + packet_identifier, + reason_code: PubAckReasonCode::Success, + properties: PubAckProperties::default(), + }, + 2, + )) } else if remaining_length < 4 { - return Err(crate::packets::error::ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), remaining_length, 4))); + return Err(crate::packets::error::ReadError::DeserializeError(DeserializeError::InsufficientData( + std::any::type_name::(), + remaining_length, + 4, + ))); } else { let (reason_code, reason_code_read_bytes) = PubAckReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = PubAckProperties::async_read(stream).await?; - - Ok((Self { - packet_identifier, - reason_code, - properties, - }, 2 + reason_code_read_bytes + properties_read_bytes)) - } + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + 2 + reason_code_read_bytes + properties_read_bytes, + )) + } } } @@ -100,22 +112,7 @@ impl WireLength for PubAck { } } -crate::packets::macros::define_properties!(PubAckProperties, - ReasonString, - UserProperty -); - -// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -// pub struct PubAckProperties { -// pub reason_string: Option>, -// pub user_properties: Vec<(Box, Box)>, -// } - -// impl PubAckProperties { -// pub fn is_empty(&self) -> bool { -// self.reason_string.is_none() && self.user_properties.is_empty() -// } -// } +crate::packets::macros::define_properties!(PubAckProperties, ReasonString, UserProperty); impl MqttRead for PubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -169,20 +166,6 @@ impl MqttWrite for PubAckProperties { } } -// impl WireLength for PubAckProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } - -// len -// } -// } - #[cfg(test)] mod tests { use crate::packets::{ diff --git a/mqrstt/src/packets/puback/reason_code.rs b/mqrstt/src/packets/puback/reason_code.rs index c38f7da..ea78a45 100644 --- a/mqrstt/src/packets/puback/reason_code.rs +++ b/mqrstt/src/packets/puback/reason_code.rs @@ -1,4 +1,3 @@ - crate::packets::macros::reason_code!( PubAckReasonCode, Success, @@ -11,57 +10,3 @@ crate::packets::macros::reason_code!( QuotaExceeded, PayloadFormatInvalid ); - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum PubAckReasonCode { -// Success, -// NoMatchingSubscribers, -// UnspecifiedError, -// ImplementationSpecificError, -// NotAuthorized, -// TopicNameInvalid, -// PacketIdentifierInUse, -// QuotaExceeded, -// PayloadFormatInvalid, -// } - -// impl MqttRead for PubAckReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(PubAckReasonCode::Success), -// 0x10 => Ok(PubAckReasonCode::NoMatchingSubscribers), -// 0x80 => Ok(PubAckReasonCode::UnspecifiedError), -// 0x83 => Ok(PubAckReasonCode::ImplementationSpecificError), -// 0x87 => Ok(PubAckReasonCode::NotAuthorized), -// 0x90 => Ok(PubAckReasonCode::TopicNameInvalid), -// 0x91 => Ok(PubAckReasonCode::PacketIdentifierInUse), -// 0x97 => Ok(PubAckReasonCode::QuotaExceeded), -// 0x99 => Ok(PubAckReasonCode::PayloadFormatInvalid), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for PubAckReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// PubAckReasonCode::Success => 0x00, -// PubAckReasonCode::NoMatchingSubscribers => 0x10, -// PubAckReasonCode::UnspecifiedError => 0x80, -// PubAckReasonCode::ImplementationSpecificError => 0x83, -// PubAckReasonCode::NotAuthorized => 0x87, -// PubAckReasonCode::TopicNameInvalid => 0x90, -// PubAckReasonCode::PacketIdentifierInUse => 0x91, -// PubAckReasonCode::QuotaExceeded => 0x97, -// PubAckReasonCode::PayloadFormatInvalid => 0x99, -// }; - -// buf.put_u8(val); - -// Ok(()) -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs index 3cbb8cf..f508600 100644 --- a/mqrstt/src/packets/pubcomp/properties.rs +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -1,26 +1,11 @@ - - use crate::packets::{ error::DeserializeError, mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, }; - crate::packets::macros::define_properties!(PubCompProperties, ReasonString, UserProperty); -// #[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] -// pub struct PubCompProperties { -// pub reason_string: Option>, -// pub user_properties: Vec<(Box, Box)>, -// } - -// impl PubCompProperties { -// pub fn is_empty(&self) -> bool { -// self.reason_string.is_none() && self.user_properties.is_empty() -// } -// } - impl MqttRead for PubCompProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf)?; @@ -72,17 +57,3 @@ impl MqttWrite for PubCompProperties { Ok(()) } } - -// impl WireLength for PubCompProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } - -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index c63063e..2570309 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -2,10 +2,8 @@ mod properties; pub use properties::PublishProperties; use tokio::io::AsyncReadExt; -use tokio::io::AsyncBufRead; - -use bytes::{BufMut, Bytes}; +use bytes::BufMut; use crate::error::PacketValidationError; use crate::util::constants::MAXIMUM_TOPIC_SIZE; @@ -13,7 +11,8 @@ use crate::util::constants::MAXIMUM_TOPIC_SIZE; use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; use super::VariableInteger; use super::{ - error::{DeserializeError, SerializeError}, QoS, + error::{DeserializeError, SerializeError}, + QoS, }; #[derive(Debug, Default, Clone, PartialEq, Eq)] @@ -84,7 +83,10 @@ impl PacketRead for Publish { } } -impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin { +impl PacketAsyncRead for Publish +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; @@ -94,7 +96,9 @@ impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin { let (topic, topic_read_bytes) = Box::::async_read(stream).await?; total_read_bytes += topic_read_bytes; - let packet_identifier = if qos == QoS::AtMostOnce { None } else { + let packet_identifier = if qos == QoS::AtMostOnce { + None + } else { total_read_bytes += 2; Some(stream.read_u16().await?) }; @@ -107,21 +111,18 @@ impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin { assert_eq!(payload_read_bytes, payload_len); - - Ok( - ( - Self { - dup, - qos, - retain, - topic, - packet_identifier, - publish_properties, - payload, - }, - total_read_bytes + payload_read_bytes - ) - ) + Ok(( + Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload, + }, + total_read_bytes + payload_read_bytes, + )) } } } @@ -171,14 +172,13 @@ impl PacketValidation for Publish { } } - - #[cfg(test)] mod tests { use bytes::{BufMut, BytesMut}; use crate::packets::{ - mqtt_trait::{PacketRead, PacketWrite}, VariableInteger, + mqtt_trait::{PacketRead, PacketWrite}, + VariableInteger, }; use super::Publish; diff --git a/mqrstt/src/packets/publish/properties.rs b/mqrstt/src/packets/publish/properties.rs index 9d7790e..dac9f39 100644 --- a/mqrstt/src/packets/publish/properties.rs +++ b/mqrstt/src/packets/publish/properties.rs @@ -8,7 +8,8 @@ use crate::packets::{ PacketType, PropertyType, }; -crate::packets::macros::define_properties!(PublishProperties, +crate::packets::macros::define_properties!( + PublishProperties, PayloadFormatIndicator, MessageExpiryInterval, ContentType, @@ -19,43 +20,6 @@ crate::packets::macros::define_properties!(PublishProperties, UserProperty ); - -// #[derive(Debug, Clone, PartialEq, Eq, Default)] -// pub struct PublishProperties { -// /// 3.3.2.3.2 Payload Format Indicator -// /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. -// pub payload_format_indicator: Option, - -// /// 3.3.2.3.3 Message Expiry Interval -// /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. -// pub message_expiry_interval: Option, - -// /// 3.3.2.3.9 Content Type -// /// 3 (0x03) Identifier of the Content Type -// pub content_type: Option>, - -// /// 3.3.2.3.5 Response Topic -// /// 8 (0x08) Byte, Identifier of the Response Topic. -// pub response_topic: Option>, - -// /// 3.3.2.3.6 Correlation Data -// /// 9 (0x09) Byte, Identifier of the Correlation Data. -// pub correlation_data: Option, - -// /// 3.3.2.3.8 Subscription Identifier -// /// 11 (0x0B), Identifier of the Subscription Identifier. -// pub subscription_identifier: Vec, - -// /// 3.3.2.3.4 Topic Alias -// /// 35 (0x23) Byte, Identifier of the Topic Alias. -// pub topic_alias: Option, - -// /// 3.3.2.3.7 User Property -// /// 38 (0x26) Byte, Identifier of the User Property. -// pub user_properties: Vec<(Box, Box)>, - -// } - impl MqttRead for PublishProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf).map_err(DeserializeError::from)?; @@ -164,36 +128,3 @@ impl MqttWrite for PublishProperties { Ok(()) } } - -// impl WireLength for PublishProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; - -// if self.payload_format_indicator.is_some() { -// len += 2; -// } -// if self.message_expiry_interval.is_some() { -// len += 5; -// } -// if self.topic_alias.is_some() { -// len += 3; -// } -// if let Some(response_topic) = &self.response_topic { -// len += 1 + response_topic.wire_len(); -// } -// if let Some(correlation_data) = &self.correlation_data { -// len += 1 + correlation_data.wire_len(); -// } -// for sub_id in &self.subscription_identifier { -// len += 1 + *sub_id.variable_integer_len(); -// } -// for (key, val) in &self.user_properties { -// len += 1 + key.wire_len() + val.wire_len(); -// } -// if let Some(content_type) = &self.content_type { -// len += 1 + content_type.wire_len(); -// } - -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrec/properties.rs b/mqrstt/src/packets/pubrec/properties.rs index 66b8fcf..146a149 100644 --- a/mqrstt/src/packets/pubrec/properties.rs +++ b/mqrstt/src/packets/pubrec/properties.rs @@ -1,22 +1,11 @@ use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, }; crate::packets::macros::define_properties!(PubRecProperties, ReasonString, UserProperty); -// #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] -// pub struct PubRecProperties { -// pub reason_string: Option>, -// pub user_properties: Vec<(Box, Box)>, -// } - -// impl PubRecProperties { -// pub fn is_empty(&self) -> bool { -// self.reason_string.is_none() && self.user_properties.is_empty() -// } -// } - impl MqttRead for PubRecProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf)?; @@ -68,17 +57,3 @@ impl MqttWrite for PubRecProperties { Ok(()) } } - -// impl WireLength for PubRecProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } - -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrec/reason_code.rs b/mqrstt/src/packets/pubrec/reason_code.rs index 327a79c..6d7cf6e 100644 --- a/mqrstt/src/packets/pubrec/reason_code.rs +++ b/mqrstt/src/packets/pubrec/reason_code.rs @@ -1,6 +1,5 @@ - - -crate::packets::macros::reason_code!(PubRecReasonCode, +crate::packets::macros::reason_code!( + PubRecReasonCode, Success, NoMatchingSubscribers, UnspecifiedError, @@ -11,57 +10,3 @@ crate::packets::macros::reason_code!(PubRecReasonCode, QuotaExceeded, PayloadFormatInvalid ); - - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum PubRecReasonCode { -// Success, -// NoMatchingSubscribers, -// UnspecifiedError, -// ImplementationSpecificError, -// NotAuthorized, -// TopicNameInvalid, -// PacketIdentifierInUse, -// QuotaExceeded, -// PayloadFormatInvalid, -// } - -// impl MqttRead for PubRecReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(PubRecReasonCode::Success), -// 0x10 => Ok(PubRecReasonCode::NoMatchingSubscribers), -// 0x80 => Ok(PubRecReasonCode::UnspecifiedError), -// 0x83 => Ok(PubRecReasonCode::ImplementationSpecificError), -// 0x87 => Ok(PubRecReasonCode::NotAuthorized), -// 0x90 => Ok(PubRecReasonCode::TopicNameInvalid), -// 0x91 => Ok(PubRecReasonCode::PacketIdentifierInUse), -// 0x97 => Ok(PubRecReasonCode::QuotaExceeded), -// 0x99 => Ok(PubRecReasonCode::PayloadFormatInvalid), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for PubRecReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// PubRecReasonCode::Success => 0x00, -// PubRecReasonCode::NoMatchingSubscribers => 0x10, -// PubRecReasonCode::UnspecifiedError => 0x80, -// PubRecReasonCode::ImplementationSpecificError => 0x83, -// PubRecReasonCode::NotAuthorized => 0x87, -// PubRecReasonCode::TopicNameInvalid => 0x90, -// PubRecReasonCode::PacketIdentifierInUse => 0x91, -// PubRecReasonCode::QuotaExceeded => 0x97, -// PubRecReasonCode::PayloadFormatInvalid => 0x99, -// }; - -// buf.put_u8(val); -// Ok(()) -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrel/properties.rs b/mqrstt/src/packets/pubrel/properties.rs index 6ccad2a..f4a97c8 100644 --- a/mqrstt/src/packets/pubrel/properties.rs +++ b/mqrstt/src/packets/pubrel/properties.rs @@ -1,12 +1,10 @@ use crate::packets::{ - error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger + error::DeserializeError, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, + PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(PubRelProperties, - ReasonString, - UserProperty -); +crate::packets::macros::define_properties!(PubRelProperties, ReasonString, UserProperty); impl PubRelProperties { pub fn is_empty(&self) -> bool { @@ -46,47 +44,6 @@ impl MqttRead for PubRelProperties { } } -// impl MqttAsyncRead for PubRelProperties where S: tokio::io::AsyncReadExt + Unpin { -// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { -// let (len, length_variable_integer) = read_async_variable_integer(stream).await?; -// if len == 0 { -// return Ok((Self::default(), length_variable_integer)); -// } - -// let mut properties = PubRelProperties::default(); - -// let mut read_property_bytes = 0; -// loop { -// let (prop, read_bytes) = PropertyType::async_read(stream).await?; -// read_property_bytes += read_bytes; -// match prop { -// PropertyType::ReasonString => { -// if properties.reason_string.is_some() { -// return Err(super::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty(PropertyType::ReasonString))); -// } -// let (prop_body, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; -// properties.reason_string = Some(prop_body); -// } -// PropertyType::UserProperty => { -// let (prop_body_key, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; -// let (prop_body_value, read_bytes) = Box::::async_read(stream).await?; -// read_property_bytes += read_bytes; - -// properties.user_properties.push((prop_body_key, prop_body_value)) -// }, -// e => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), -// } -// if read_property_bytes == len { -// break; -// } -// } - -// Ok((properties, length_variable_integer + read_property_bytes)) -// } -// } - impl MqttWrite for PubRelProperties { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { let len = self.wire_len(); @@ -106,17 +63,3 @@ impl MqttWrite for PubRelProperties { Ok(()) } } - -// impl WireLength for PubRelProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += reason_string.wire_len() + 1; -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } - -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/pubrel/reason_code.rs b/mqrstt/src/packets/pubrel/reason_code.rs index 275410b..24c44f6 100644 --- a/mqrstt/src/packets/pubrel/reason_code.rs +++ b/mqrstt/src/packets/pubrel/reason_code.rs @@ -1,49 +1 @@ -crate::packets::macros::reason_code!( - PubRelReasonCode, - Success, - PacketIdentifierNotFound -); - - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum PubRelReasonCode { -// Success, -// PacketIdentifierNotFound, -// } - -// impl MqttRead for PubRelReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(PubRelReasonCode::Success), -// 0x92 => Ok(PubRelReasonCode::PacketIdentifierNotFound), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttAsyncRead for PubRelReasonCode where S: tokio::io::AsyncReadExt + Unpin { -// async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { -// let code = match stream.read_u8().await? { -// 0x00 => PubRelReasonCode::Success, -// 0x92 => PubRelReasonCode::PacketIdentifierNotFound, -// t => return Err(super::error::ReadError::DeserializeError(DeserializeError::UnknownProperty(t))), -// }; -// Ok((code, 1)) -// } -// } - -// impl MqttWrite for PubRelReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { -// let val = match self { -// PubRelReasonCode::Success => 0x00, -// PubRelReasonCode::PacketIdentifierNotFound => 0x92, -// }; - -// buf.put_u8(val); -// Ok(()) -// } -// } +crate::packets::macros::reason_code!(PubRelReasonCode, Success, PacketIdentifierNotFound); diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs index 8a06e7e..b88c5b8 100644 --- a/mqrstt/src/packets/suback/properties.rs +++ b/mqrstt/src/packets/suback/properties.rs @@ -1,26 +1,12 @@ use crate::packets::{ error::DeserializeError, - mqtt_trait::{MqttRead, MqttWrite, WireLength}, + mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, }; use crate::packets::primitive::VariableInteger; -crate::packets::macros::define_properties!(SubAckProperties, - SubscriptionIdentifier, - UserProperty -); - -// #[derive(Debug, Default, PartialEq, Eq, Clone)] -// pub struct SubAckProperties { -// /// 3.8.2.1.2 Subscription Identifier -// /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. -// pub subscription_id: Option, - -// /// 3.8.2.1.3 User Property -// /// 38 (0x26) Byte, Identifier of the User Property. -// pub user_properties: Vec<(Box, Box)>, -// } +crate::packets::macros::define_properties!(SubAckProperties, SubscriptionIdentifier, UserProperty); impl MqttRead for SubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -30,7 +16,7 @@ impl MqttRead for SubAckProperties { if len == 0 { return Ok(properties); - } else if buf.len() < len { + } else if buf.len() < len { return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), len)); } @@ -77,16 +63,3 @@ impl MqttWrite for SubAckProperties { Ok(()) } } - -// impl WireLength for SubAckProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(sub_id) = self.subscription_id { -// len += 1 + sub_id.variable_integer_len(); -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } -// len -// } -// } diff --git a/mqrstt/src/packets/suback/reason_code.rs b/mqrstt/src/packets/suback/reason_code.rs index 3185a87..0b89706 100644 --- a/mqrstt/src/packets/suback/reason_code.rs +++ b/mqrstt/src/packets/suback/reason_code.rs @@ -1,5 +1,3 @@ - - crate::packets::macros::reason_code!( SubAckReasonCode, GrantedQoS0, @@ -15,47 +13,3 @@ crate::packets::macros::reason_code!( SubscriptionIdentifiersNotSupported, WildcardSubscriptionsNotSupported ); - - - - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum SubAckReasonCode { -// GrantedQoS0, -// GrantedQoS1, -// GrantedQoS2, -// ImplementationSpecificError, -// NotAuthorized, -// TopicFilterInvalid, -// PacketIdentifierInUse, -// QuotaExceeded, -// SharedSubscriptionsNotSupported, -// SubscriptionIdentifiersNotSupported, -// WildcardSubscriptionsNotSupported, -// } - -// impl MqttRead for SubAckReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(SubAckReasonCode::GrantedQoS0), -// 0x01 => Ok(SubAckReasonCode::GrantedQoS1), -// 0x02 => Ok(SubAckReasonCode::GrantedQoS2), -// 0x80 => Ok(SubAckReasonCode::UnspecifiedError), -// 0x83 => Ok(SubAckReasonCode::ImplementationSpecificError), -// 0x87 => Ok(SubAckReasonCode::NotAuthorized), -// 0x8F => Ok(SubAckReasonCode::TopicFilterInvalid), -// 0x91 => Ok(SubAckReasonCode::PacketIdentifierInUse), -// 0x97 => Ok(SubAckReasonCode::QuotaExceeded), -// 0x9E => Ok(SubAckReasonCode::SharedSubscriptionsNotSupported), -// 0xA1 => Ok(SubAckReasonCode::SubscriptionIdentifiersNotSupported), -// 0xA2 => Ok(SubAckReasonCode::WildcardSubscriptionsNotSupported), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for SubAckReasonCode { \ No newline at end of file diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index de9b082..997b105 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -1,5 +1,4 @@ mod properties; -use std::ops::Sub; pub use properties::SubscribeProperties; use tokio::io::AsyncReadExt; @@ -9,7 +8,7 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, - PacketType, PropertyType, QoS, VariableInteger, + QoS, VariableInteger, }; use bytes::{Buf, BufMut}; @@ -54,14 +53,17 @@ impl PacketRead for Subscribe { } } -impl PacketAsyncRead for Subscribe where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for Subscribe +where + S: tokio::io::AsyncReadExt + Unpin, +{ fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; let packet_identifier = stream.read_u16().await?; let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; total_read_bytes += 2 + proproperties_read_bytes; - + let mut topics = vec![]; loop { let (topic, topic_read_bytes) = Box::::async_read(stream).await?; @@ -74,16 +76,18 @@ impl PacketAsyncRead for Subscribe where S: tokio::io::AsyncReadExt + Unpi } } - Ok((Self { - packet_identifier, - properties, - topics, - }, total_read_bytes)) + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } } - impl PacketWrite for Subscribe { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { buf.put_u16(self.packet_identifier); @@ -167,7 +171,10 @@ impl MqttRead for SubscriptionOptions { } } -impl MqttAsyncRead for SubscriptionOptions where S: tokio::io::AsyncRead + Unpin { +impl MqttAsyncRead for SubscriptionOptions +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(stream: &mut S) -> impl std::future::Future> { async move { let byte = stream.read_u8().await?; @@ -189,7 +196,6 @@ impl MqttAsyncRead for SubscriptionOptions where S: tokio::io::AsyncRead + } } - impl MqttWrite for SubscriptionOptions { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); diff --git a/mqrstt/src/packets/subscribe/properties.rs b/mqrstt/src/packets/subscribe/properties.rs index 4ef6e36..27559cc 100644 --- a/mqrstt/src/packets/subscribe/properties.rs +++ b/mqrstt/src/packets/subscribe/properties.rs @@ -1,29 +1,10 @@ - - -use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; - use crate::packets::{ error::DeserializeError, mqtt_trait::{MqttRead, MqttWrite, WireLength}, PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!( - SubscribeProperties, - SubscriptionIdentifier, - UserProperty -); - -// #[derive(Debug, Default, PartialEq, Eq, Clone)] -// pub struct SubscribeProperties { -// /// 3.8.2.1.2 Subscription Identifier -// /// 11 (0x0B) Byte, Identifier of the Subscription Identifier. -// pub subscription_id: Option, - -// /// 3.8.2.1.3 User Property -// /// 38 (0x26) Byte, Identifier of the User Property. -// pub user_properties: Vec<(Box, Box)>, -// } +crate::packets::macros::define_properties!(SubscribeProperties, SubscriptionIdentifier, UserProperty); impl MqttRead for SubscribeProperties { fn read(buf: &mut bytes::Bytes) -> Result { @@ -79,16 +60,3 @@ impl MqttWrite for SubscribeProperties { Ok(()) } } - -// impl WireLength for SubscribeProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(sub_id) = self.subscription_identifier { -// len += 1 + sub_id.variable_integer_len(); -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } -// len -// } -// } \ No newline at end of file diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs index e16ee0c..138e85a 100644 --- a/mqrstt/src/packets/unsuback/properties.rs +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -1,19 +1,9 @@ - -use crate::packets::error::{DeserializeError}; +use crate::packets::error::DeserializeError; use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; use crate::packets::{PacketType, PropertyType, VariableInteger}; crate::packets::macros::define_properties!(UnsubAckProperties, ReasonString, UserProperty); -// #[derive(Debug, Default, PartialEq, Eq, Clone)] -// pub struct UnsubAckProperties { -// /// 3.11.2.1.2 Reason String -// /// 31 (0x1F) Byte, Identifier of the Reason String. -// pub reason_string: Option>, - -// pub user_properties: Vec<(Box, Box)>, -// } - impl MqttRead for UnsubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf)?; @@ -66,16 +56,3 @@ impl MqttWrite for UnsubAckProperties { Ok(()) } } - -// impl WireLength for UnsubAckProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// if let Some(reason_string) = &self.reason_string { -// len += 1 + reason_string.wire_len(); -// } -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } -// len -// } -// } diff --git a/mqrstt/src/packets/unsuback/reason_code.rs b/mqrstt/src/packets/unsuback/reason_code.rs index 292655f..9ac2036 100644 --- a/mqrstt/src/packets/unsuback/reason_code.rs +++ b/mqrstt/src/packets/unsuback/reason_code.rs @@ -8,51 +8,3 @@ crate::packets::macros::reason_code!( TopicFilterInvalid, PacketIdentifierInUse ); - - -// #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] -// pub enum UnsubAckReasonCode { -// Success, -// NoSubscriptionExisted, -// UnspecifiedError, -// ImplementationSpecificError, -// NotAuthorized, -// TopicFilterInvalid, -// PacketIdentifierInUse, -// } - -// impl MqttRead for UnsubAckReasonCode { -// fn read(buf: &mut bytes::Bytes) -> Result { -// if buf.is_empty() { -// return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); -// } - -// match buf.get_u8() { -// 0x00 => Ok(UnsubAckReasonCode::Success), -// 0x11 => Ok(UnsubAckReasonCode::NoSubscriptionExisted), -// 0x80 => Ok(UnsubAckReasonCode::UnspecifiedError), -// 0x83 => Ok(UnsubAckReasonCode::ImplementationSpecificError), -// 0x87 => Ok(UnsubAckReasonCode::NotAuthorized), -// 0x8F => Ok(UnsubAckReasonCode::TopicFilterInvalid), -// 0x91 => Ok(UnsubAckReasonCode::PacketIdentifierInUse), -// t => Err(DeserializeError::UnknownProperty(t)), -// } -// } -// } - -// impl MqttWrite for UnsubAckReasonCode { -// fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), crate::packets::error::SerializeError> { -// let val = match self { -// UnsubAckReasonCode::Success => 0x00, -// UnsubAckReasonCode::NoSubscriptionExisted => 0x11, -// UnsubAckReasonCode::UnspecifiedError => 0x80, -// UnsubAckReasonCode::ImplementationSpecificError => 0x83, -// UnsubAckReasonCode::NotAuthorized => 0x87, -// UnsubAckReasonCode::TopicFilterInvalid => 0x8F, -// UnsubAckReasonCode::PacketIdentifierInUse => 0x91, -// }; - -// buf.put_u8(val); -// Ok(()) -// } -// } diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs index 41c8091..e7c628f 100644 --- a/mqrstt/src/packets/unsubscribe/properties.rs +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -6,14 +6,8 @@ use crate::packets::{ PacketType, PropertyType, }; - crate::packets::macros::define_properties!(UnsubscribeProperties, UserProperty); -// #[derive(Debug, Default, PartialEq, Eq, Clone)] -// pub struct UnsubscribeProperties { -// pub user_properties: Vec<(String, String)>, -// } - impl MqttRead for UnsubscribeProperties { fn read(buf: &mut bytes::Bytes) -> Result { let (len, _) = VariableInteger::read_variable_integer(buf)?; @@ -55,13 +49,3 @@ impl MqttWrite for UnsubscribeProperties { Ok(()) } } - -// impl WireLength for UnsubscribeProperties { -// fn wire_len(&self) -> usize { -// let mut len = 0; -// for (key, value) in &self.user_properties { -// len += 1 + key.wire_len() + value.wire_len(); -// } -// len -// } -// } From 173f911077c2a11c1ff78e4cea7b4150cf7748a5 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Sat, 23 Nov 2024 23:31:05 +0100 Subject: [PATCH 12/45] Touch up docs --- mqrstt/src/client.rs | 98 ++++++------- mqrstt/src/connect_options.rs | 12 +- mqrstt/src/error.rs | 5 +- mqrstt/src/event_handlers.rs | 2 - mqrstt/src/lib.rs | 82 +++-------- mqrstt/src/packets/auth/mod.rs | 17 ++- mqrstt/src/packets/auth/properties.rs | 9 +- mqrstt/src/packets/connack/mod.rs | 75 ++++++---- mqrstt/src/packets/connack/properties.rs | 9 +- mqrstt/src/packets/connect/connect_flags.rs | 11 +- .../src/packets/connect/connect_properties.rs | 54 ++----- mqrstt/src/packets/connect/last_will.rs | 28 ++-- .../packets/connect/last_will_properties.rs | 1 + mqrstt/src/packets/connect/mod.rs | 33 +++-- mqrstt/src/packets/disconnect/mod.rs | 66 ++++----- mqrstt/src/packets/error.rs | 8 +- .../src/packets/macros/properties_macros.rs | 134 +++++++++--------- .../src/packets/macros/reason_code_macros.rs | 7 +- mqrstt/src/packets/mod.rs | 4 +- .../src/packets/mqtt_trait/primitive_impl.rs | 48 +++++-- mqrstt/src/packets/primitive/mod.rs | 3 +- .../src/packets/primitive/protocol_version.rs | 16 ++- mqrstt/src/packets/puback/mod.rs | 9 +- mqrstt/src/packets/pubcomp/mod.rs | 48 ++++--- mqrstt/src/packets/pubcomp/properties.rs | 7 +- mqrstt/src/packets/publish/mod.rs | 8 ++ mqrstt/src/packets/publish/properties.rs | 1 + mqrstt/src/packets/pubrec/mod.rs | 46 +++--- mqrstt/src/packets/pubrec/properties.rs | 7 +- mqrstt/src/packets/pubrel/mod.rs | 3 + mqrstt/src/packets/pubrel/properties.rs | 7 +- mqrstt/src/packets/suback/mod.rs | 34 +++-- mqrstt/src/packets/suback/properties.rs | 7 +- mqrstt/src/packets/subscribe/mod.rs | 30 ++-- mqrstt/src/packets/subscribe/properties.rs | 7 +- mqrstt/src/packets/unsuback/mod.rs | 1 + mqrstt/src/packets/unsuback/properties.rs | 7 +- mqrstt/src/packets/unsubscribe/mod.rs | 40 +++--- mqrstt/src/packets/unsubscribe/properties.rs | 6 +- mqrstt/src/smol/network.rs | 2 +- mqrstt/src/tests/test_packets.rs | 4 +- 41 files changed, 529 insertions(+), 467 deletions(-) diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 4123d73..d9f4335 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -8,22 +8,22 @@ use crate::{ packets::{ mqtt_trait::PacketValidation, DisconnectReasonCode, - Packet, QoS, - // disconnect::{Disconnect, DisconnectProperties}, - // publish::{Publish, PublishProperties}, - // subscribe::{Subscribe, SubscribeProperties, Subscription}, + Packet, + QoS, + // disconnect::{Disconnect, DisconnectProperties}, + // publish::{Publish, PublishProperties}, + // subscribe::{Subscribe, SubscribeProperties, Subscription}, // unsubscribe::{Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, - - {Disconnect, DisconnectProperties}, - {Publish, PublishProperties}, - {Subscribe, SubscribeProperties, Subscription}, + {Disconnect, DisconnectProperties}, + {Publish, PublishProperties}, + {Subscribe, SubscribeProperties, SubscribeTopics}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, }, }; #[derive(Debug, Clone)] -/// A Clonable client that can be used to perform MQTT operations -/// +/// A Clonable client that can be used to send MQTT messages +/// /// This object is never self constructed but is a obtained by calling the builder functions on [`crate::NetworkBuilder`] pub struct MqttClient { /// Provides this client with an available packet id or waits on it. @@ -56,8 +56,6 @@ impl MqttClient { } } - - /// This function is only here for you to use during testing of for example your handler /// For control over the input of this type look at [`MqttClient::test_custom_client`] /// @@ -102,7 +100,7 @@ impl MqttClient { impl MqttClient { /// Creates a subscribe packet that is then asynchronously transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` @@ -133,9 +131,9 @@ impl MqttClient { /// mqtt_client.subscribe(("final/test/topic", sub_options)).await; /// # }); /// ``` - pub async fn subscribe>(&self, into_subscribtions: A) -> Result<(), ClientError> { + pub async fn subscribe>(&self, into_subscribtions: A) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?; - let subscription: Subscription = into_subscribtions.into(); + let subscription: SubscribeTopics = into_subscribtions.into(); let sub = Subscribe::new(pkid, subscription.0); sub.validate(self.max_packet_size)?; @@ -146,7 +144,7 @@ impl MqttClient { /// Creates a subscribe packet with additional subscribe packet properties. /// The packet is then asynchronously transferred to the Network stack for transmission. /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` @@ -160,7 +158,7 @@ impl MqttClient { /// subscription_identifier: Some(1), /// user_properties: vec![], /// }; - /// + /// /// let sub_properties_clone = sub_properties.clone(); /// /// // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, @@ -192,7 +190,7 @@ impl MqttClient { /// mqtt_client.subscribe_with_properties(("final/test/topic", sub_options), sub_properties).await; /// # }); /// ``` - pub async fn subscribe_with_properties>(&self, into_sub: S, properties: SubscribeProperties) -> Result<(), ClientError> { + pub async fn subscribe_with_properties>(&self, into_sub: S, properties: SubscribeProperties) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv().await.map_err(|_| ClientError::NoNetworkChannel)?; let sub = Subscribe { packet_identifier: pkid, @@ -276,7 +274,7 @@ impl MqttClient { /// correlation_data: Some("correlation_data".into()), /// ..Default::default() /// }; - /// + /// /// # let properties_clone = properties.clone(); /// /// // publish a message with QoS 0, without a packet identifier @@ -503,12 +501,12 @@ impl MqttClient { impl MqttClient { /// Creates a subscribe packet that is then transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// This function blocks until the packet is queued for transmission /// Creates a subscribe packet that is then asynchronously transferred to the Network stack for transmission /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// # Examples /// ``` @@ -538,9 +536,9 @@ impl MqttClient { /// mqtt_client.subscribe_blocking(("final/test/topic", sub_options)).unwrap(); /// # }); /// ``` - pub fn subscribe_blocking>(&self, into_subscribtions: A) -> Result<(), ClientError> { + pub fn subscribe_blocking>(&self, into_subscribtions: A) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?; - let subscription: Subscription = into_subscribtions.into(); + let subscription: SubscribeTopics = into_subscribtions.into(); let sub = Subscribe::new(pkid, subscription.0); sub.validate(self.max_packet_size)?; @@ -551,7 +549,7 @@ impl MqttClient { /// Creates a subscribe packet with additional subscribe packet properties. /// The packet is then transferred to the Network stack for transmission. /// - /// Can be called with anything that can be converted into [`Subscription`] + /// Can be called with anything that can be converted into [`SubscribeTopics`] /// /// This function blocks until the packet is queued for transmission /// # Examples @@ -566,7 +564,7 @@ impl MqttClient { /// user_properties: vec![], /// }; /// # let sub_properties_clone = sub_properties.clone(); - /// + /// /// // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, /// mqtt_client.subscribe_with_properties_blocking("test/topic", sub_properties).unwrap(); /// @@ -596,7 +594,7 @@ impl MqttClient { /// mqtt_client.subscribe_with_properties_blocking(("final/test/topic", sub_options), sub_properties).unwrap(); /// # }); /// ``` - pub fn subscribe_with_properties_blocking>(&self, into_subscribtions: S, properties: SubscribeProperties) -> Result<(), ClientError> { + pub fn subscribe_with_properties_blocking>(&self, into_subscribtions: S, properties: SubscribeProperties) -> Result<(), ClientError> { let pkid = self.available_packet_ids_r.recv_blocking().map_err(|_| ClientError::NoNetworkChannel)?; let sub = Subscribe { packet_identifier: pkid, @@ -618,7 +616,7 @@ impl MqttClient { /// ``` /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::QoS; /// use bytes::Bytes; /// @@ -674,7 +672,7 @@ impl MqttClient { /// ``` /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::QoS; /// use mqrstt::packets::PublishProperties; /// use bytes::Bytes; @@ -684,7 +682,7 @@ impl MqttClient { /// correlation_data: Some("correlation_data".into()), /// ..Default::default() /// }; - /// + /// /// # let properties_clone = properties.clone(); /// /// // publish a message with QoS 0, without a packet identifier @@ -764,7 +762,7 @@ impl MqttClient { /// // Unsubscribe from multiple topics specified as an array of String: /// let topics = &[String::from("test/topic1"), String::from("test/topic2")]; /// mqtt_client.unsubscribe_blocking(topics.as_slice()); - /// + /// /// # }); /// # std::hint::black_box(network); /// ``` @@ -827,7 +825,7 @@ impl MqttClient { /// // Unsubscribe from multiple topics specified as an array of String: /// let topics = ["test/topic1","test/topic2"]; /// mqtt_client.unsubscribe_with_properties_blocking(topics.as_slice(), properties); - /// + /// /// # }); /// # std::hint::black_box(network); /// ``` @@ -853,7 +851,7 @@ impl MqttClient { /// ``` /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { - /// + /// /// mqtt_client.disconnect_blocking().unwrap(); /// /// # }); @@ -879,7 +877,7 @@ impl MqttClient { /// /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); /// # smol::block_on(async { - /// + /// /// use mqrstt::packets::DisconnectProperties; /// use mqrstt::packets::DisconnectReasonCode; /// @@ -954,15 +952,15 @@ mod tests { #[tokio::test] async fn test_subscribe_with_properties() { let (mqtt_client, client_to_handler_r, to_network_r) = create_new_test_client(); - - let sub_properties = SubscribeProperties{ + + let sub_properties = SubscribeProperties { subscription_identifier: Some(1), user_properties: vec![], }; // retain_handling: RetainHandling::ZERO, retain_as_publish: false, no_local: false, qos: QoS::AtMostOnce, let res = mqtt_client.subscribe_with_properties("test/topic", sub_properties.clone()).await; - + assert!(res.is_ok()); let packet = client_to_handler_r.recv().await.unwrap(); // assert!(matches!(packet, Packet::Subscribe(sub) if sub.properties.subscription_id == Some(1))); @@ -972,7 +970,7 @@ mod tests { } #[test] - + fn test_subscribe_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); @@ -1023,7 +1021,6 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[test] fn test_unsubscribe_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); @@ -1049,12 +1046,11 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[test] fn test_unsubscribe_with_properties_blocking() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); - let properties = UnsubscribeProperties{ + let properties = UnsubscribeProperties { user_properties: vec![("property".into(), "value".into())], }; @@ -1098,12 +1094,11 @@ mod tests { assert_eq!(res.unwrap_err(), ClientError::ValidationError(PacketValidationError::TopicSize(65538))); } - #[tokio::test] async fn publish_with_properties() { let (client, client_to_handler_r, to_network_r) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_other_data".into()), ..Default::default() @@ -1115,7 +1110,7 @@ mod tests { assert!(res.is_ok()); let packet = client_to_handler_r.recv().await.unwrap(); - let publ = Publish{ + let publ = Publish { dup: false, qos: *qos, retain: false, @@ -1131,12 +1126,11 @@ mod tests { std::hint::black_box((client, client_to_handler_r, to_network_r)); } - #[tokio::test] async fn publish_with_just_right_topic_len_properties() { let (client, _client_to_handler_r, _) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_data".into()), ..Default::default() @@ -1151,7 +1145,7 @@ mod tests { async fn publish_with_too_long_topic_properties() { let (client, _client_to_handler_r, _) = create_new_test_client(); - let properties = crate::packets::PublishProperties{ + let properties = crate::packets::PublishProperties { response_topic: Some("response/topic".into()), correlation_data: Some("correlation_data".into()), ..Default::default() @@ -1226,12 +1220,10 @@ mod tests { let disconnect = client_to_handler_r.recv().await.unwrap(); assert_eq!(PacketType::Disconnect, disconnect.packet_type()); - assert!( - matches!(disconnect, Packet::Disconnect(res) - if res.properties == DisconnectProperties::default() && - DisconnectReasonCode::NormalDisconnection == res.reason_code - ) - ); + assert!(matches!(disconnect, Packet::Disconnect(res) + if res.properties == DisconnectProperties::default() && + DisconnectReasonCode::NormalDisconnection == res.reason_code + )); } #[tokio::test] @@ -1259,7 +1251,6 @@ mod tests { assert!(matches!(disconnect, Packet::Disconnect(res) if properties == res.properties && DisconnectReasonCode::KeepAliveTimeout == res.reason_code)); } - #[test] fn test_disconnect_blocking() { let (client, client_to_handler_r, _) = create_new_test_client(); @@ -1288,5 +1279,4 @@ mod tests { assert!(matches!(disconnect, Packet::Disconnect(res) if properties == res.properties && DisconnectReasonCode::KeepAliveTimeout == res.reason_code)); } - } diff --git a/mqrstt/src/connect_options.rs b/mqrstt/src/connect_options.rs index 224d5aa..c5a8671 100644 --- a/mqrstt/src/connect_options.rs +++ b/mqrstt/src/connect_options.rs @@ -1,6 +1,5 @@ use std::time::Duration; - use crate::util::constants::DEFAULT_RECEIVE_MAXIMUM; use crate::{ packets::{ConnectProperties, LastWill}, @@ -9,10 +8,11 @@ use crate::{ #[derive(Debug, thiserror::Error)] pub enum ConnectOptionsError { - #[error("Maximum packet size is exceeded. Maximum is {MAXIMUM_PACKET_SIZE}, was provided {0}")] - MaximumPacketSize(u32), + #[error("Maximum packet size is exceeded. Maximum is {MAXIMUM_PACKET_SIZE}, user provided: {0}")] + MaximumPacketSizeExceeded(u32), } +/// Options for the connection to the MQTT broker #[derive(Debug, Clone)] pub struct ConnectOptions { /// client identifier @@ -71,7 +71,7 @@ impl Default for ConnectOptions { impl ConnectOptions { /// Create a new [`ConnectOptions`] - /// + /// /// Be aware: /// This client does not restrict the client identifier in any way. However, the MQTT v5.0 specification does. /// It is thus recommended to use a client id that is compatible with the MQTT v5.0 specification. @@ -121,7 +121,7 @@ impl ConnectOptions { username: self.username.clone(), password: self.password.clone(), connect_properties, - protocol_version: crate::packets::protocol_version::ProtocolVersion::V5, + protocol_version: crate::packets::ProtocolVersion::V5, last_will: self.last_will.clone(), }; @@ -221,7 +221,7 @@ impl ConnectOptions { pub fn set_maximum_packet_size(&mut self, maximum_packet_size: u32) -> Result<&mut Self, ConnectOptionsError> { if maximum_packet_size > MAXIMUM_PACKET_SIZE { - Err(ConnectOptionsError::MaximumPacketSize(maximum_packet_size)) + Err(ConnectOptionsError::MaximumPacketSizeExceeded(maximum_packet_size)) } else { self.maximum_packet_size = Some(maximum_packet_size); Ok(self) diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index 7b6cac7..e4b018c 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -4,8 +4,7 @@ use async_channel::{RecvError, SendError}; use crate::packets::{ error::{DeserializeError, ReadBytes, SerializeError}, - ConnAckReasonCode, - {Packet, PacketType}, + ConnAckReasonCode, {Packet, PacketType}, }; /// Critical errors that can happen during the operation of the entire client @@ -43,7 +42,7 @@ pub enum ConnectionError { JoinError(#[from] tokio::task::JoinError), } -/// Errors that the [`crate::StateHandler`] can emit +/// Errors that the internal StateHandler can emit #[derive(Debug, Clone, thiserror::Error)] pub enum HandlerError { #[error("Missing Packet ID")] diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 5ef33f4..f8ef9c9 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -8,8 +8,6 @@ use crate::packets::Packet; /// This guarantees that the end user has handlded the packet. /// Trait for async mutable access to handler. /// Usefull when you have a single handler - -/// This trait can be used types which pub trait AsyncEventHandler { fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; } diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index 1de3c95..23b1744 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -1,31 +1,27 @@ //! A pure rust MQTT client which is easy to use, efficient and provides both sync and async options. //! //! Because this crate aims to be runtime agnostic the user is required to provide their own data stream. -//! For an async approach the stream has to implement the `AsyncReadExt` and `AsyncWriteExt` traits. -//! That is [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] for tokio and [`::smol::io::AsyncReadExt`] and [`::smol::io::AsyncWriteExt`] for smol. +//! For an async approach the stream has to implement the `AsyncRead` and `AsyncWrite` traits. +//! That is [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] for tokio and [`::smol::io::AsyncRead`] and [`::smol::io::AsyncWrite`] for smol. //! //! Features: //! ---------------------------- //! - MQTT v5 //! - Runtime agnostic (Smol, Tokio) -//! - TLS/TCP +//! - Packets are acknoledged after handler has processed them +//! - Runs on just a stream so you can use all TCP backends //! - Lean //! - Keep alive depends on actual communication //! //! To do //! ---------------------------- -//! - Enforce size of outbound messages (e.g. Publish) -//! - QUIC via QUINN //! - Even More testing -//! - More documentation -//! - Remove logging calls or move all to test flag //! //! Notes: //! ---------------------------- -//! - Your handler should not wait too long -//! - Create a new connection when an error or disconnect is encountered +//! - While the handler is processing a message the stream blocks. To prevent this, spawn a task in the handler or use [tokio::ConcurrentHandler]. //! - Handlers only get incoming packets -//! - Sync mode requires a non blocking stream +//! - Create a new connection when an error or disconnect is encountered //! //! Smol example: //! ---------------------------- @@ -111,59 +107,6 @@ //! assert!(n.is_ok()); //! } //! ``` -//! -// //! Sync example: -// //! ---------------------------- -// //! ```rust -// //! use mqrstt::{ -// //! MqttClient, -// //! example_handlers::NOP, -// //! ConnectOptions, -// //! packets::{self, Packet}, -// //! EventHandler, -// //! sync::NetworkStatus, -// //! }; -// //! use std::net::TcpStream; -// //! -// //! let mut client_id: String = "SyncTcppingrespTestExample".to_string(); -// //! let options = ConnectOptions::new(client_id); -// //! -// //! let address = "broker.emqx.io"; -// //! let port = 1883; -// //! -// //! let (mut network, client) = new_sync(options); -// //! -// //! // Construct a no op handler -// //! let mut nop = NOP{}; -// //! -// //! // In normal operations you would want to loop connect -// //! // To reconnect after a disconnect or error -// //! let stream = TcpStream::connect((address, port)).unwrap(); -// //! // IMPORTANT: Set nonblocking to true! No progression will be made when stream reads block! -// //! stream.set_nonblocking(true).unwrap(); -// //! network.connect(stream, &mut nop).unwrap(); -// //! -// //! let res_join_handle = std::thread::spawn(move || -// //! loop { -// //! match network.poll(&mut nop) { -// //! Ok(NetworkStatus::ActivePending) => { -// //! std::thread::sleep(std::time::Duration::from_millis(100)); -// //! }, -// //! Ok(NetworkStatus::ActiveReady) => { -// //! std::thread::sleep(std::time::Duration::from_millis(100)); -// //! }, -// //! otherwise => return otherwise, -// //! } -// //! } -// //! ); -// //! -// //! std::thread::sleep(std::time::Duration::from_secs(30)); -// //! client.disconnect_blocking().unwrap(); -// //! let join_res = res_join_handle.join(); -// //! assert!(join_res.is_ok()); -// //! let res = join_res.unwrap(); -// //! assert!(res.is_ok()); -// //! ``` const CHANNEL_SIZE: usize = 100; @@ -173,16 +116,26 @@ mod connect_options; mod state_handler; mod util; +/// Contains the reader writer parts for the smol runtime. +/// +/// Module [`crate::smol`] only contains a synchronized approach to call the users `Handler`. #[cfg(feature = "smol")] pub mod smol; +/// Contains the reader and writer parts for the tokio runtime. +/// +/// Module [`crate::tokio`] contains both a synchronized and concurrent approach to call the users `Handler`. #[cfg(any(feature = "tokio"))] pub mod tokio; +/// Error types that the user can see during operation of the client. +/// +/// Wraps all other errors that can be encountered. pub mod error; + mod event_handlers; +/// All MQTT packets are defined here pub mod packets; mod state; -use std::marker::PhantomData; pub use event_handlers::*; @@ -190,6 +143,7 @@ pub use client::MqttClient; pub use connect_options::ConnectOptions; use state_handler::StateHandler; +use std::marker::PhantomData; #[cfg(test)] pub mod tests; diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index ad4905c..412d87d 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -5,12 +5,14 @@ pub use reason_code::AuthReasonCode; use bytes::Bytes; -use super::{mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger}; - +use super::{ + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, +}; #[derive(Debug, Clone, PartialEq, Eq)] /// The AUTH packet is used to perform more intriquite authentication methods. -/// +/// /// At the time of writing this client does not (yet) provide the user a method of handling the auth handshake. /// There are several other ways to perform authentication, for example using TLS. /// Additionally, not many clients support this packet fully. @@ -28,12 +30,15 @@ impl PacketRead for Auth { } } -impl PacketAsyncRead for Auth where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for Auth +where + S: tokio::io::AsyncRead + Unpin, +{ async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { let (reason_code, reason_code_read_bytes) = AuthReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = AuthProperties::async_read(stream).await?; - Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes )) + Ok((Self { reason_code, properties }, reason_code_read_bytes + properties_read_bytes)) } } @@ -49,4 +54,4 @@ impl WireLength for Auth { fn wire_len(&self) -> usize { 1 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() } -} \ No newline at end of file +} diff --git a/mqrstt/src/packets/auth/properties.rs b/mqrstt/src/packets/auth/properties.rs index 1ed1a60..e905897 100644 --- a/mqrstt/src/packets/auth/properties.rs +++ b/mqrstt/src/packets/auth/properties.rs @@ -6,7 +6,14 @@ use crate::packets::{ PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(AuthProperties, AuthenticationMethod, AuthenticationData, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// Properties of the AUTH packet + AuthProperties, + AuthenticationMethod, + AuthenticationData, + ReasonString, + UserProperty +); impl MqttRead for AuthProperties { fn read(buf: &mut Bytes) -> Result { diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 56d0259..1caf53b 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -4,14 +4,19 @@ pub use properties::ConnAckProperties; mod reason_code; pub use reason_code::ConnAckReasonCode; - use super::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, }; use bytes::{Buf, BufMut}; - +/// ConnAck packet is sent by the server in response to a [`crate::packets::Connect`] packet. +/// +/// The ConnAck packet contains the values used by the server related to this connection. +/// +/// For example the requested client identifier can be changed by the server. +/// This is then indicated using the property [`crate::packets::ConnAckProperties::assigned_client_identifier`]. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct ConnAck { /// 3.2.2.1 Connect Acknowledge Flags @@ -43,22 +48,24 @@ impl PacketRead for ConnAck { } } -impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for ConnAck +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(_: u8, _: usize, stream: &mut S) -> impl std::future::Future> { async move { let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; - + Ok(( Self { connack_flags, reason_code, connack_properties, }, - read_bytes + reason_code_read_bytes + connack_properties_read_bytes + read_bytes + reason_code_read_bytes + connack_properties_read_bytes, )) - } } } @@ -81,19 +88,24 @@ impl WireLength for ConnAck { } } - #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct ConnAckFlags { pub session_present: bool, } -impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncReadExt + Unpin { +impl MqttAsyncRead for ConnAckFlags +where + S: tokio::io::AsyncReadExt + Unpin, +{ fn async_read(stream: &mut S) -> impl std::future::Future> { async move { let byte = stream.read_u8().await?; - Ok((Self { - session_present: (byte & 0b00000001) == 0b00000001, - }, 1)) + Ok(( + Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, + 1, + )) } } } @@ -125,7 +137,9 @@ impl MqttWrite for ConnAckFlags { mod tests { use crate::packets::{ - connack::{ConnAck, ConnAckProperties}, mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, ConnAckReasonCode, Packet, VariableInteger + connack::{ConnAck, ConnAckProperties}, + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + ConnAckReasonCode, Packet, VariableInteger, }; #[test] @@ -133,23 +147,23 @@ mod tests { let mut buf = bytes::BytesMut::new(); let connack_properties = ConnAckProperties { - session_expiry_interval: Some(60), // Session expiry interval in seconds - receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently - maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server - retain_available: Some(true), // Whether the server supports retained messages - maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept - assigned_client_id: Some(Box::from("client-12345")), // Client identifier assigned by the server - topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server - reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment - user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair - wildcards_available: Some(true), // Whether wildcard subscriptions are available - subscription_ids_available: Some(true), // Whether subscription identifiers are available - shared_subscription_available: Some(true), // Whether shared subscriptions are available - server_keep_alive: Some(120), // Server keep alive time in seconds - response_info: Some(Box::from("Response info")), // Response information - server_reference: Some(Box::from("server-reference")), // Server reference - authentication_method: Some(Box::from("auth-method")), // Authentication method - authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data + session_expiry_interval: Some(60), // Session expiry interval in seconds + receive_maximum: Some(20), // Maximum number of QoS 1 and QoS 2 publications that the client is willing to process concurrently + maximum_qos: Some(crate::packets::QoS::AtMostOnce), // Maximum QoS level supported by the server + retain_available: Some(true), // Whether the server supports retained messages + maximum_packet_size: Some(1024), // Maximum packet size the server is willing to accept + assigned_client_identifier: Some(Box::from("client-12345")), // Client identifier assigned by the server + topic_alias_maximum: Some(10), // Maximum number of topic aliases supported by the server + reason_string: Some(Box::from("Connection accepted")), // Reason string for the connection acknowledgment + user_properties: vec![(Box::from("key1"), Box::from("value1"))], // User property key-value pair + wildcards_available: Some(true), // Whether wildcard subscriptions are available + subscription_ids_available: Some(true), // Whether subscription identifiers are available + shared_subscription_available: Some(true), // Whether shared subscriptions are available + server_keep_alive: Some(120), // Server keep alive time in seconds + response_info: Some(Box::from("Response info")), // Response information + server_reference: Some(Box::from("server-reference")), // Server reference + authentication_method: Some(Box::from("auth-method")), // Authentication method + authentication_data: Some(vec![1, 2, 3, 4]), // Authentication data }; let len = connack_properties.wire_len(); @@ -160,7 +174,6 @@ mod tests { connack_properties.write(&mut buf).unwrap(); assert_eq!(len + len_of_wire_len, buf.len()); - } #[test] diff --git a/mqrstt/src/packets/connack/properties.rs b/mqrstt/src/packets/connack/properties.rs index 979b978..670b23b 100644 --- a/mqrstt/src/packets/connack/properties.rs +++ b/mqrstt/src/packets/connack/properties.rs @@ -6,6 +6,7 @@ use crate::packets::{ use bytes::BufMut; crate::packets::macros::define_properties!( + /// ConnAck Properties ConnAckProperties, SessionExpiryInterval, ReceiveMaximum, @@ -73,10 +74,10 @@ impl MqttRead for ConnAckProperties { properties.maximum_packet_size = Some(u32::read(&mut property_data)?); } PropertyType::AssignedClientIdentifier => { - if properties.assigned_client_id.is_some() { + if properties.assigned_client_identifier.is_some() { return Err(DeserializeError::DuplicateProperty(PropertyType::AssignedClientIdentifier)); } - properties.assigned_client_id = Some(Box::::read(&mut property_data)?); + properties.assigned_client_identifier = Some(Box::::read(&mut property_data)?); } PropertyType::TopicAliasMaximum => { if properties.topic_alias_maximum.is_some() { @@ -162,7 +163,7 @@ impl MqttWrite for ConnAckProperties { maximum_qos, retain_available, maximum_packet_size, - assigned_client_id, + assigned_client_identifier, topic_alias_maximum, reason_string, user_properties, @@ -196,7 +197,7 @@ impl MqttWrite for ConnAckProperties { PropertyType::MaximumPacketSize.write(buf)?; buf.put_u32(*maximum_packet_size); } - if let Some(client_id) = assigned_client_id { + if let Some(client_id) = assigned_client_identifier { PropertyType::AssignedClientIdentifier.write(buf)?; client_id.write(buf)?; } diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs index 08a62a8..2066b51 100644 --- a/mqrstt/src/packets/connect/connect_flags.rs +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -1,9 +1,11 @@ use bytes::{Buf, BufMut}; +use tokio::io::AsyncReadExt; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, QoS, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + QoS, }; /// The connect flags describe some information related the session. @@ -71,7 +73,10 @@ impl MqttRead for ConnectFlags { } } -impl MqttAsyncRead for ConnectFlags where S: tokio::io::AsyncReadExt + Unpin { +impl MqttAsyncRead for ConnectFlags +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(stream: &mut S) -> impl std::future::Future> { async move { let byte = stream.read_u8().await?; @@ -85,4 +90,4 @@ impl MqttWrite for ConnectFlags { buf.put_u8(self.into_u8()?); Ok(()) } -} \ No newline at end of file +} diff --git a/mqrstt/src/packets/connect/connect_properties.rs b/mqrstt/src/packets/connect/connect_properties.rs index 9a22591..7b3f13a 100644 --- a/mqrstt/src/packets/connect/connect_properties.rs +++ b/mqrstt/src/packets/connect/connect_properties.rs @@ -7,15 +7,15 @@ use crate::packets::{ PacketType, PropertyType, WireLength, }; -// /// Connect Properties -// /// -// /// The wire representation starts with the length of all properties after which -// /// the identifiers and their actual value are given -// /// -// /// 3.1.2.11.1 Property Length -// /// The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. -// /// Followed by all possible connect properties: +// / +// / The wire representation starts with the length of all properties after which +// / the identifiers and their actual value are given +// / +// / 3.1.2.11.1 Property Length +// / The length of the Properties in the CONNECT packet Variable Header encoded as a Variable Byte Integer. +// / Followed by all possible connect properties: crate::packets::macros::define_properties!( + /// Connect Properties ConnectProperties, SessionExpiryInterval, ReceiveMaximum, @@ -155,41 +155,3 @@ impl MqttRead for ConnectProperties { Ok(properties) } } - -// impl WireLength for ConnectProperties { -// fn wire_len(&self) -> usize { -// let mut len: usize = 0; - -// if self.session_expiry_interval.is_some() { -// len += 1 + 4; -// } -// if self.receive_maximum.is_some() { -// len += 1 + 2; -// } -// if self.maximum_packet_size.is_some() { -// len += 1 + 4; -// } -// if self.topic_alias_maximum.is_some() { -// len += 1 + 2; -// } -// if self.request_response_information.is_some() { -// len += 2; -// } -// if self.request_problem_information.is_some() { -// len += 2; -// } -// for (key, value) in &self.user_properties { -// len += 1; -// len += key.wire_len(); -// len += value.wire_len(); -// } -// if let Some(authentication_method) = &self.authentication_method { -// len += 1 + authentication_method.wire_len(); -// } -// if !self.authentication_data.is_empty() && self.authentication_method.is_some() { -// len += 1 + self.authentication_data.wire_len(); -// } - -// len -// } -// } diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index b195b32..5d89ec0 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -1,9 +1,9 @@ use bytes::{Bytes, BytesMut}; - use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, QoS, WireLength, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + QoS, WireLength, }; use super::{LastWillProperties, VariableInteger}; @@ -46,20 +46,26 @@ impl LastWill { last_will_properties, }) } - pub(crate) async fn async_read(qos: QoS, retain: bool, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> where S: tokio::io::AsyncReadExt + Unpin{ + pub(crate) async fn async_read(qos: QoS, retain: bool, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> + where + S: tokio::io::AsyncRead + Unpin, + { let (last_will_properties, last_will_properties_read_bytes) = LastWillProperties::async_read(stream).await?; let (topic, topic_read_bytes) = Box::::async_read(stream).await?; let (payload, payload_read_bytes) = Vec::::async_read(stream).await?; let total_read_bytes = last_will_properties_read_bytes + topic_read_bytes + payload_read_bytes; - Ok((Self { - qos, - retain, - last_will_properties, - topic, - payload, - }, total_read_bytes)) + Ok(( + Self { + qos, + retain, + last_will_properties, + topic, + payload, + }, + total_read_bytes, + )) } } @@ -78,4 +84,4 @@ impl WireLength for LastWill { self.topic.wire_len() + self.payload.wire_len() + property_len.variable_integer_len() + property_len } -} \ No newline at end of file +} diff --git a/mqrstt/src/packets/connect/last_will_properties.rs b/mqrstt/src/packets/connect/last_will_properties.rs index 9e7db09..0a580ba 100644 --- a/mqrstt/src/packets/connect/last_will_properties.rs +++ b/mqrstt/src/packets/connect/last_will_properties.rs @@ -8,6 +8,7 @@ use crate::packets::{ }; crate::packets::macros::define_properties!( + /// Last Will Properties LastWillProperties, WillDelayInterval, PayloadFormatIndicator, diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs index 0e4a0c6..1674db2 100644 --- a/mqrstt/src/packets/connect/mod.rs +++ b/mqrstt/src/packets/connect/mod.rs @@ -10,22 +10,24 @@ pub use connect_properties::ConnectProperties; mod last_will; pub use last_will::LastWill; -use bytes::{Buf, BufMut, Bytes, BytesMut}; - use crate::packets::error::ReadError; use super::{ - error::{DeserializeError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, protocol_version::ProtocolVersion, VariableInteger, WireLength + error::{DeserializeError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + ProtocolVersion, VariableInteger, WireLength, }; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::AsyncReadExt; /// Connect packet send by the client to the server to initialize a connection. -/// +/// /// Variable Header /// - Protocol Name and Version: Identifies the MQTT protocol and version. /// - Connect Flags: Options like clean start, will flag, will QoS, will retain, password flag, and username flag. /// - Keep Alive Interval: Maximum time interval between messages. /// - Properties: Optional settings such as session expiry interval, receive maximum, maximum packet size, and topic alias maximum. -/// +/// /// Payload /// - Client Identifier: Unique ID for the client. /// - Will Message: Optional message sent if the client disconnects unexpectedly. @@ -108,7 +110,10 @@ impl PacketRead for Connect { } } -impl PacketAsyncRead for Connect where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for Connect +where + S: tokio::io::AsyncRead + Unpin, +{ async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { let mut total_read_bytes = 0; let expected_protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; @@ -127,9 +132,9 @@ impl PacketAsyncRead for Connect where S: tokio::io::AsyncReadExt + Unpin let keep_alive = stream.read_u16().await?; // Add keep alive read bytes total_read_bytes += 2; - + let (connect_properties, prop_read_bytes) = ConnectProperties::async_read(stream).await?; - let (client_id, client_read_bytes) = Box::::async_read(stream).await?; + let (client_id, client_read_bytes) = Box::::async_read(stream).await?; total_read_bytes += prop_read_bytes + client_read_bytes; let last_will = if connect_flags.will_flag { @@ -141,14 +146,18 @@ impl PacketAsyncRead for Connect where S: tokio::io::AsyncReadExt + Unpin None }; - let (username, username_read_bytes) = if connect_flags.username { + let (username, username_read_bytes) = if connect_flags.username { let (username, username_read_bytes) = Box::::async_read(stream).await?; (Some(username), username_read_bytes) - } else { (None, 0) }; - let (password, password_read_bytes) = if connect_flags.password { + } else { + (None, 0) + }; + let (password, password_read_bytes) = if connect_flags.password { let (password, password_read_bytes) = Box::::async_read(stream).await?; (Some(password), password_read_bytes) - } else { (None, 0) }; + } else { + (None, 0) + }; total_read_bytes += username_read_bytes + password_read_bytes; diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs index bb8573b..e5d34ee 100644 --- a/mqrstt/src/packets/disconnect/mod.rs +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -5,9 +5,13 @@ mod reason_code; pub use reason_code::DisconnectReasonCode; use super::{ - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, }; +/// The DISCONNECT Packet is the final packet. +/// The client sends this packet to the server to disconnect for example on calling [`crate::MqttClient::disconnect`]. +/// The server can send a disconnect packet to the client to indicate that the connection is being closed. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Disconnect { pub reason_code: DisconnectReasonCode, @@ -20,7 +24,13 @@ where { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { if remaining_length == 0 { - Ok((Self { reason_code: DisconnectReasonCode::NormalDisconnection, properties: DisconnectProperties::default() }, 0)) + Ok(( + Self { + reason_code: DisconnectReasonCode::NormalDisconnection, + properties: DisconnectProperties::default(), + }, + 0, + )) } else { let (reason_code, reason_code_read_bytes) = DisconnectReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = DisconnectProperties::async_read(stream).await?; @@ -76,10 +86,7 @@ mod tests { properties: DisconnectProperties { session_expiry_interval: Some(123), reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], server_reference: Some(Box::from("Server reference")), }, reason_code: DisconnectReasonCode::NormalDisconnection, @@ -96,15 +103,9 @@ mod tests { assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); assert_eq!( read_packet.properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_packet.properties.server_reference, - Some(Box::from("Server reference")) + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] ); + assert_eq!(read_packet.properties.server_reference, Some(Box::from("Server reference"))); } #[test] @@ -114,10 +115,7 @@ mod tests { properties: DisconnectProperties { session_expiry_interval: Some(123), reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], server_reference: Some(Box::from("Server reference")), }, reason_code: DisconnectReasonCode::NormalDisconnection, @@ -131,15 +129,9 @@ mod tests { assert_eq!(read_packet.properties.reason_string, Some(Box::from("Some reason"))); assert_eq!( read_packet.properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_packet.properties.server_reference, - Some(Box::from("Server reference")) + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] ); + assert_eq!(read_packet.properties.server_reference, Some(Box::from("Server reference"))); } #[test] @@ -148,30 +140,20 @@ mod tests { let properties = DisconnectProperties { session_expiry_interval: Some(123), reason_string: Some(Box::from("Some reason")), - user_properties: vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ], + user_properties: vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2"))], server_reference: Some(Box::from("Server reference")), }; - + properties.write(&mut buf).unwrap(); - + let read_properties = DisconnectProperties::read(&mut buf.into()).unwrap(); - + assert_eq!(read_properties.session_expiry_interval, Some(123)); assert_eq!(read_properties.reason_string, Some(Box::from("Some reason"))); assert_eq!( read_properties.user_properties, - vec![ - (Box::from("key1"), Box::from("value1")), - (Box::from("key2"), Box::from("value2")), - ] - ); - assert_eq!( - read_properties.server_reference, - Some(Box::from("Server reference")) + vec![(Box::from("key1"), Box::from("value1")), (Box::from("key2"), Box::from("value2")),] ); + assert_eq!(read_properties.server_reference, Some(Box::from("Server reference"))); } } - diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index 7585b3a..53abcc8 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -13,7 +13,7 @@ pub enum WriteError { } #[derive(Error, Debug)] -pub enum ReadError{ +pub enum ReadError { #[error("{0}")] DeserializeError(#[from] DeserializeError), #[error("{0}")] @@ -39,10 +39,10 @@ pub enum DeserializeError { #[error("There is insufficient for {0} data ({1}) to take {2} bytes")] InsufficientData(&'static str, usize, usize), - + #[error("There is insufficient to read the protocol version.")] InsufficientDataForProtocolVersion, - + #[error("Read more data for the packet than indicated length")] ReadTooMuchData(&'static str, usize, usize), @@ -72,7 +72,7 @@ impl From for DeserializeError { } #[derive(Error, Clone, Debug)] -pub enum ReadBytes { +pub(crate) enum ReadBytes { #[error("Normal error")] Err(#[from] T), diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index d6a01c1..89c20ed 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -1,8 +1,11 @@ macro_rules! define_properties { - ($name:ident, $($prop_variant:ident),*) => { - $crate::packets::macros::properties_struct!(@ $name { $($prop_variant,)* } -> ()); + ($(#[$attr:meta])* $name:ident, $($prop_variant:ident),*) => { + $crate::packets::macros::properties_struct!(@ + $(#[$attr])* + $name { $($prop_variant,)* } -> () + ); - impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + Unpin { + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncRead + Unpin { async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; if len == 0 { @@ -43,54 +46,55 @@ macro_rules! define_properties { } macro_rules! properties_struct { - ( @ $name:ident { } -> ($($result:tt)*) ) => ( + ( @ $(#[$attr:meta])* $name:ident { } -> ($($result:tt)*) ) => ( + // $(#[$attr])* #[derive(Debug, PartialEq, Eq, Clone, Hash, Default)] pub struct $name { $($result)* } ); - ( @ $name:ident { PayloadFormatIndicator, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { PayloadFormatIndicator, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.2 Payload Format Indicator /// 1 (0x01) Byte, Identifier of the Payload Format Indicator. pub payload_format_indicator: Option, )); ); - ( @ $name:ident { MessageExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { MessageExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.3 Message Expiry Interval /// 2 (0x02) Byte, Identifier of the Message Expiry Interval. pub message_expiry_interval: Option, )); ); - ( @ $name:ident { ContentType, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ContentType, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.9 Content Type /// 3 (0x03) Identifier of the Content Type pub content_type: Option>, )); ); - ( @ $name:ident { ResponseTopic, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ResponseTopic, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.5 Response Topic /// 8 (0x08) Byte, Identifier of the Response Topic. pub response_topic: Option>, )); ); - ( @ $name:ident { CorrelationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { CorrelationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.6 Correlation Data /// 9 (0x09) Byte, Identifier of the Correlation Data. pub correlation_data: Option>, )); ); - ( @ $name:ident { ListSubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ListSubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.8 Subscription Identifier /// 11 (0x0B), Identifier of the Subscription Identifier. @@ -98,183 +102,183 @@ macro_rules! properties_struct { pub subscription_identifiers: Vec, )); ); - ( @ $name:ident { SubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { SubscriptionIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.8 Subscription Identifier /// 11 (0x0B), Identifier of the Subscription Identifier. pub subscription_identifier: Option, )); ); - ( @ $name:ident { SessionExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { SessionExpiryInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.2 Session Expiry Interval /// 17 (0x11) Byte Identifier of the Session Expiry Interval pub session_expiry_interval: Option, )); ); - ( @ $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { AssignedClientIdentifier, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.7 Assigned Client Identifier /// 18 (0x12) Byte, Identifier of the Assigned Client Identifier. - pub assigned_client_id: Option>, + pub assigned_client_identifier: Option>, )); ); - ( @ $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ServerKeepAlive, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.14 Server Keep Alive /// 19 (0x13) Byte, Identifier of the Server Keep Alive pub server_keep_alive: Option, )); ); - ( @ $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { AuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.17 Authentication Method /// 21 (0x15) Byte, Identifier of the Authentication Method pub authentication_method: Option>, )); ); - ( @ $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { AuthenticationData, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.18 Authentication Data /// 22 (0x16) Byte, Identifier of the Authentication Data pub authentication_data: Option>, )); ); - ( @ $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { RequestProblemInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.1.2.11.7 Request Problem Information /// 23 (0x17) Byte, Identifier of the Request Problem Information pub request_problem_information: Option, )); ); - ( @ $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { WillDelayInterval, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.1.3.2.2 Request Problem Information /// 24 (0x18) Byte, Identifier of the Will Delay Interval. pub will_delay_interval: Option, )); ); - ( @ $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { RequestResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.1.2.11.6 Request Response Information /// 25 (0x19) Byte, Identifier of the Request Response Information pub request_response_information: Option, )); ); - ( @ $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ResponseInformation, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.15 Response Information /// 26 (0x1A) Byte, Identifier of the Response Information. pub response_info: Option>, )); ); - ( @ $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ServerReference, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.16 Server Reference /// 28 (0x1C) Byte, Identifier of the Server Reference pub server_reference: Option>, )); ); - ( @ $name:ident { ReasonString, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ReasonString, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.9 Reason String /// 31 (0x1F) Byte Identifier of the Reason String. pub reason_string: Option>, )); ); - ( @ $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { ReceiveMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.3 Receive Maximum /// 33 (0x21) Byte, Identifier of the Receive Maximum pub receive_maximum: Option, )); ); - ( @ $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { TopicAliasMaximum, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.8 Topic Alias Maximum /// 34 (0x22) Byte, Identifier of the Topic Alias Maximum. pub topic_alias_maximum: Option, )); ); - ( @ $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { TopicAlias, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.3.2.3.4 Topic Alias /// 35 (0x23) Byte, Identifier of the Topic Alias. pub topic_alias: Option, )); ); - ( @ $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { MaximumQos, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.4 Maximum QoS /// 36 (0x24) Byte, Identifier of the Maximum QoS. pub maximum_qos: Option<$crate::packets::QoS>, )); ); - ( @ $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { RetainAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.5 Retain Available /// 37 (0x25) Byte, Identifier of Retain Available. pub retain_available: Option, )); ); - ( @ $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { UserProperty, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.10 User Property /// 38 (0x26) Byte, Identifier of User Property. pub user_properties: Vec<(Box, Box)>, )); ); - ( @ $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { MaximumPacketSize, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.6 Maximum Packet Size /// 39 (0x27) Byte, Identifier of the Maximum Packet Size. pub maximum_packet_size: Option, )); ); - ( @ $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { WildcardSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.11 Wildcard Subscription Available /// 40 (0x28) Byte, Identifier of Wildcard Subscription Available. pub wildcards_available: Option, )); ); - ( @ $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { SubscriptionIdentifierAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.12 Subscription Identifiers Available /// 41 (0x29) Byte, Identifier of Subscription Identifier Available. pub subscription_ids_available: Option, )); ); - ( @ $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::properties_struct!(@ $name { $($rest)* } -> ( + ( @ $(#[$attr:meta])* $name:ident { SharedSubscriptionAvailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::properties_struct!(@ $(#[$attr])* $name { $($rest)* } -> ( $($result)* /// 3.2.2.3.13 Shared Subscription Available /// 42 (0x2A) Byte, Identifier of Shared Subscription Available. pub shared_subscription_available: Option, )); ); - ( @ $name:ident { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $(#[$attr:meta])* $name:ident { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( compile_error!(concat!("Unknown property: ", stringify!($unknown))); ); } @@ -345,14 +349,14 @@ macro_rules! properties_read_match_branch_body { $properties.session_expiry_interval = Some(prop_body); }}; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::AssignedClientIdentifier) => {{ - if $properties.assigned_client_id.is_some() { + if $properties.assigned_client_identifier.is_some() { return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( PropertyType::AssignedClientIdentifier, ))); } let (prop_body, read_bytes) = Box::::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.assigned_client_id = Some(prop_body); + $properties.assigned_client_identifier = Some(prop_body); }}; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::ServerKeepAlive) => {{ if $properties.server_keep_alive.is_some() { @@ -581,7 +585,7 @@ macro_rules! properties_wire_length { } }; ($self:ident, $len:ident, PropertyType::AssignedClientIdentifier) => { - if let Some(client_id) = $self.assigned_client_id.as_ref() { + if let Some(client_id) = $self.assigned_client_identifier.as_ref() { $len += 1 + client_id.wire_len(); } }; diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs index abb1f37..b56a669 100644 --- a/mqrstt/src/packets/macros/reason_code_macros.rs +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -1,12 +1,13 @@ macro_rules! reason_code { ($name:ident, $($code:ident),*) => { + use tokio::io::AsyncReadExt; #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum $name { #[default] $($code),* } - impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncReadExt + std::marker::Unpin{ + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncRead + std::marker::Unpin{ async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { let input = stream.read_u8().await?; let res = $crate::packets::macros::reason_code_match!(@ $name, input, { @@ -504,7 +505,7 @@ macro_rules! reason_code_match_write{ $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( $($result)* $name::PacketIdentifierNotFound => 0x92, - + )) ); ( @ $name:ident, $buf:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( @@ -611,4 +612,4 @@ macro_rules! reason_code_match_write{ pub(crate) use reason_code; pub(crate) use reason_code_match; -pub(crate) use reason_code_match_write; \ No newline at end of file +pub(crate) use reason_code_match_write; diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 31d3707..c77969f 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -1,5 +1,5 @@ pub mod error; -pub mod mqtt_trait; +pub(crate) mod mqtt_trait; mod macros; @@ -42,7 +42,7 @@ use std::fmt::Display; use self::error::{DeserializeError, SerializeError}; use self::mqtt_trait::{PacketRead, PacketWrite, WireLength}; -// ==================== Packets ==================== +/// Enum to bundle the different MQTT packets. #[derive(Debug, Clone, PartialEq, Eq)] pub enum Packet { Connect(Connect), diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs index 76b9a82..503c4f5 100644 --- a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -1,8 +1,7 @@ -use bytes::{BufMut, Buf, Bytes, BytesMut}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::packets::mqtt_trait::{MqttRead, MqttAsyncRead, MqttWrite, WireLength}; use crate::packets::error::{DeserializeError, ReadError, SerializeError}; - +use crate::packets::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, WireLength}; impl MqttRead for Box { #[inline] @@ -16,7 +15,10 @@ impl MqttRead for Box { } } -impl MqttAsyncRead for Box where S: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for Box +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { let (content, read_bytes) = Vec::async_read(stream).await?; match String::from_utf8(content) { @@ -68,7 +70,10 @@ impl MqttRead for String { } } -impl MqttAsyncRead for String where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for String +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { let (content, read_bytes) = Bytes::async_read(buf).await?; match String::from_utf8(content.to_vec()) { @@ -110,7 +115,10 @@ impl MqttRead for Bytes { Ok(buf.split_to(len)) } } -impl MqttAsyncRead for Bytes where S: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for Bytes +where + S: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { let size = stream.read_u16().await? as usize; // let mut data = BytesMut::with_capacity(size); @@ -148,7 +156,7 @@ impl MqttRead for Vec { Ok(buf.split_to(len).into()) } } -impl MqttWrite for Vec { +impl MqttWrite for Vec { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.len() as u16); @@ -163,7 +171,10 @@ impl WireLength for Vec { self.len() + 2 } } -impl MqttAsyncRead for Vec where S: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for Vec +where + S: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { let size = stream.read_u16().await? as usize; // let mut data = BytesMut::with_capacity(size); @@ -174,7 +185,6 @@ impl MqttAsyncRead for Vec where S: tokio::io::AsyncReadExt + std::mar } } - impl MqttRead for bool { fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { @@ -188,7 +198,10 @@ impl MqttRead for bool { } } } -impl MqttAsyncRead for bool where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for bool +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await? { 0 => Ok((false, 1)), @@ -219,7 +232,10 @@ impl MqttRead for u8 { Ok(buf.get_u8()) } } -impl MqttAsyncRead for u8 where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for u8 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { Ok((buf.read_u8().await?, 1)) } @@ -234,7 +250,10 @@ impl MqttRead for u16 { Ok(buf.get_u16()) } } -impl MqttAsyncRead for u16 where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for u16 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { Ok((buf.read_u16().await?, 2)) } @@ -256,7 +275,10 @@ impl MqttRead for u32 { Ok(buf.get_u32()) } } -impl MqttAsyncRead for u32 where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for u32 +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { Ok((buf.read_u32().await?, 4)) } diff --git a/mqrstt/src/packets/primitive/mod.rs b/mqrstt/src/packets/primitive/mod.rs index 1b098a5..4f9d756 100644 --- a/mqrstt/src/packets/primitive/mod.rs +++ b/mqrstt/src/packets/primitive/mod.rs @@ -1,7 +1,8 @@ mod fixed_header; pub(crate) use fixed_header::FixedHeader; -pub mod protocol_version; +mod protocol_version; +pub use protocol_version::ProtocolVersion; mod property_type; pub(crate) use property_type::PropertyType; diff --git a/mqrstt/src/packets/primitive/protocol_version.rs b/mqrstt/src/packets/primitive/protocol_version.rs index 7830cd2..84ca01a 100644 --- a/mqrstt/src/packets/primitive/protocol_version.rs +++ b/mqrstt/src/packets/primitive/protocol_version.rs @@ -1,8 +1,15 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; +use tokio::io::AsyncReadExt; -/// Protocol version +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, +}; + +/// Protocol version of the MQTT connection +/// +/// This client only supports MQTT v5.0. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum ProtocolVersion { V5, @@ -30,7 +37,10 @@ impl MqttRead for ProtocolVersion { } } -impl MqttAsyncRead for ProtocolVersion where S: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for ProtocolVersion +where + S: tokio::io::AsyncRead + std::marker::Unpin, +{ async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { match stream.read_u8().await { Ok(5) => Ok((ProtocolVersion::V5, 1)), diff --git a/mqrstt/src/packets/puback/mod.rs b/mqrstt/src/packets/puback/mod.rs index a6e35cc..e91a56a 100644 --- a/mqrstt/src/packets/puback/mod.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -9,6 +9,8 @@ use super::{ PacketType, PropertyType, VariableInteger, }; +/// The PUBACK Packet is the response to a PUBLISH Packet with QoS 1. +/// Both the server and client can send a PUBACK packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubAck { pub packet_identifier: u16, @@ -112,7 +114,12 @@ impl WireLength for PubAck { } } -crate::packets::macros::define_properties!(PubAckProperties, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// PubAck Properties + PubAckProperties, + ReasonString, + UserProperty +); impl MqttRead for PubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 2c4f07f..24355aa 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -1,19 +1,21 @@ mod reason_code; pub use reason_code::PubCompReasonCode; - mod properties; pub use properties::PubCompProperties; - - -use bytes::BufMut; - use super::{ error::DeserializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, }; +use bytes::BufMut; +use tokio::io::AsyncReadExt; +/// The PUBCOMP Packet is the response to a PUBLISH Packet with QoS 2. +/// It is the fourth and final packet of the QoS 2 protocol exchange. +/// The user of the client application does not have to send this packet, it is handled internally by the client. +/// +/// Both the client and server can send this packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubComp { pub packet_identifier: u16, @@ -58,16 +60,22 @@ impl PacketRead for PubComp { } } -impl PacketAsyncRead for PubComp where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for PubComp +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let packet_identifier = stream.read_u16().await?; if remaining_length == 2 { - return Ok((Self { - packet_identifier, - reason_code: PubCompReasonCode::Success, - properties: PubCompProperties::default(), - }, 2)); + return Ok(( + Self { + packet_identifier, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), + }, + 2, + )); } // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { @@ -79,11 +87,14 @@ impl PacketAsyncRead for PubComp where S: tokio::io::AsyncReadExt + Unpin assert_eq!(2 + reason_code_read_bytes + properties_read_bytes, remaining_length); - Ok((Self { - packet_identifier, - reason_code, - properties, - }, 2 + reason_code_read_bytes + properties_read_bytes)) + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + 2 + reason_code_read_bytes + properties_read_bytes, + )) } } } @@ -116,11 +127,12 @@ impl WireLength for PubComp { } } - #[cfg(test)] mod tests { use crate::packets::{ - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubcomp::{PubComp, PubCompProperties}, PropertyType, PubCompReasonCode, VariableInteger + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + pubcomp::{PubComp, PubCompProperties}, + PropertyType, PubCompReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; diff --git a/mqrstt/src/packets/pubcomp/properties.rs b/mqrstt/src/packets/pubcomp/properties.rs index f508600..ff5fb49 100644 --- a/mqrstt/src/packets/pubcomp/properties.rs +++ b/mqrstt/src/packets/pubcomp/properties.rs @@ -4,7 +4,12 @@ use crate::packets::{ PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(PubCompProperties, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// PubComp Properties + PubCompProperties, + ReasonString, + UserProperty +); impl MqttRead for PubCompProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index 2570309..f95c852 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -15,6 +15,14 @@ use super::{ QoS, }; +/// The PUBLISH Packet is used to send data from either side of the connection. +/// This packet is handed to the [`crate::AsyncEventHandler`] to be handled by the user. +/// +/// The following flow is determined by the QoS level used in PUBLISH Packet. +/// QoS 0: Send and forget, no deliviery garantee. +/// QoS 1: Send and acknowledgement, uised to ensure that the packet is delivered at least once. +/// QoS 2: Send and 2-step acknowledgement, used to ensure that the packet is delivered only once. +/// The packet can be send using for example [`crate::MqttClient::publish`] or [`crate::MqttClient::publish_with_properties`]. #[derive(Debug, Default, Clone, PartialEq, Eq)] pub struct Publish { /// 3.3.1.1 dup diff --git a/mqrstt/src/packets/publish/properties.rs b/mqrstt/src/packets/publish/properties.rs index dac9f39..cd78081 100644 --- a/mqrstt/src/packets/publish/properties.rs +++ b/mqrstt/src/packets/publish/properties.rs @@ -9,6 +9,7 @@ use crate::packets::{ }; crate::packets::macros::define_properties!( + /// Publish Properties PublishProperties, PayloadFormatIndicator, MessageExpiryInterval, diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 8edadf4..639fcf6 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -1,4 +1,3 @@ - mod properties; pub use properties::PubRecProperties; @@ -11,9 +10,14 @@ use tokio::io::AsyncReadExt; use super::{ error::DeserializeError, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, PacketAsyncRead, + mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + PacketAsyncRead, }; +/// The [`PubRec`] (Publish Received) packet is part of the acknowledgment flow for a [`crate::packets::Publish`] with QoS 2. +/// +/// It means that the Publish has been received, the flow will continue with the [`crate::packets::pubrel::PubRel`] +/// packet and then the [`crate::packets::pubcomp::PubComp`] packet. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubRec { pub packet_identifier: u16, @@ -57,31 +61,39 @@ impl PacketRead for PubRec { } } -impl PacketAsyncRead for PubRec where S: tokio::io::AsyncRead + Unpin{ +impl PacketAsyncRead for PubRec +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; let packet_identifier = stream.read_u16().await?; total_read_bytes += 2; if remaining_length == 2 { - return Ok((Self { - packet_identifier, - reason_code: PubRecReasonCode::Success, - properties: PubRecProperties::default(), - }, total_read_bytes)); + return Ok(( + Self { + packet_identifier, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), + }, + total_read_bytes, + )); } let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; - - total_read_bytes += reason_code_read_bytes + properties_read_bytes; - Ok((Self { - packet_identifier, - properties, - reason_code, - }, total_read_bytes)) + total_read_bytes += reason_code_read_bytes + properties_read_bytes; + Ok(( + Self { + packet_identifier, + properties, + reason_code, + }, + total_read_bytes, + )) } } } @@ -117,7 +129,9 @@ impl WireLength for PubRec { #[cfg(test)] mod tests { use crate::packets::{ - mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, pubrec::{PubRec, PubRecProperties}, PropertyType, PubRecReasonCode, VariableInteger + mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, + pubrec::{PubRec, PubRecProperties}, + PropertyType, PubRecReasonCode, VariableInteger, }; use bytes::{BufMut, Bytes, BytesMut}; diff --git a/mqrstt/src/packets/pubrec/properties.rs b/mqrstt/src/packets/pubrec/properties.rs index 146a149..3f3816f 100644 --- a/mqrstt/src/packets/pubrec/properties.rs +++ b/mqrstt/src/packets/pubrec/properties.rs @@ -4,7 +4,12 @@ use crate::packets::{ PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(PubRecProperties, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// PubRec Properties + PubRecProperties, + ReasonString, + UserProperty +); impl MqttRead for PubRecProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index 974b026..c7beff6 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -11,6 +11,9 @@ use super::{ mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, }; +/// The [`PubRel`] (Publish Release) packet acknowledges the reception of a [`crate::packets::PubRec`] Packet. +/// +/// This user does not need to send this message, it is handled internally by the client. #[derive(Debug, PartialEq, Eq, Clone, Hash)] pub struct PubRel { pub packet_identifier: u16, diff --git a/mqrstt/src/packets/pubrel/properties.rs b/mqrstt/src/packets/pubrel/properties.rs index f4a97c8..09d36d3 100644 --- a/mqrstt/src/packets/pubrel/properties.rs +++ b/mqrstt/src/packets/pubrel/properties.rs @@ -4,7 +4,12 @@ use crate::packets::{ PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(PubRelProperties, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// PubRel Properties + PubRelProperties, + ReasonString, + UserProperty +); impl PubRelProperties { pub fn is_empty(&self) -> bool { diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 718909a..b34c728 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -4,15 +4,15 @@ pub use properties::SubAckProperties; mod reason_code; pub use reason_code::SubAckReasonCode; -use bytes::BufMut; - use super::{ error::SerializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, - }; +use bytes::BufMut; +use tokio::io::AsyncReadExt; -/// 3.9 SUBACK – Subscribe acknowledgement +/// SubAck packet is sent by the server in response to a [`crate::packets::Subscribe`] packet. +/// /// A SUBACK packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE packet. /// A SUBACK packet contains a list of Reason Codes, that specify the maximum QoS level that was granted or the error which was found for each Subscription that was requested by the SUBSCRIBE. #[derive(Debug, Default, PartialEq, Eq, Clone)] @@ -28,11 +28,11 @@ impl PacketRead for SubAck { let properties = SubAckProperties::read(&mut buf)?; dbg!("aa"); - + let mut reason_codes = vec![]; loop { let reason_code = SubAckReasonCode::read(&mut buf)?; - + dbg!(reason_code); reason_codes.push(reason_code); @@ -49,7 +49,10 @@ impl PacketRead for SubAck { } } -impl PacketAsyncRead for SubAck where S: tokio::io::AsyncReadExt + Unpin { +impl PacketAsyncRead for SubAck +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; @@ -61,17 +64,20 @@ impl PacketAsyncRead for SubAck where S: tokio::io::AsyncReadExt + Unpin { let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; total_read_bytes += reason_code_read_bytes; reason_codes.push(reason_code); - + if remaining_length == total_read_bytes { break; } } - - Ok((Self { - packet_identifier, - properties, - reason_codes, - }, total_read_bytes)) + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) } } } diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs index b88c5b8..be7ac8b 100644 --- a/mqrstt/src/packets/suback/properties.rs +++ b/mqrstt/src/packets/suback/properties.rs @@ -6,7 +6,12 @@ use crate::packets::{ use crate::packets::primitive::VariableInteger; -crate::packets::macros::define_properties!(SubAckProperties, SubscriptionIdentifier, UserProperty); +crate::packets::macros::define_properties!( + /// SubAck Properties + SubAckProperties, + SubscriptionIdentifier, + UserProperty +); impl MqttRead for SubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index 997b105..1b9deac 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -12,6 +12,10 @@ use super::{ }; use bytes::{Buf, BufMut}; +/// Used to subscribe to topic(s). +/// +/// Multiple topics can be subscribed from at once. +/// For convenience [`SubscribeTopics`] is provided. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Subscribe { pub packet_identifier: u16, @@ -205,10 +209,18 @@ impl MqttWrite for SubscriptionOptions { } } +// Please describe the retain handling type, what is it used for + +/// Controls how retained messages are handled +/// +/// Used when a new subscription is established. Here are the three options for retain handling: #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RetainHandling { + /// Send Retained Messages at Subscription: This is the default behavior. When a client subscribes to a topic, the broker sends any retained messages for that topic immediately. ZERO, + /// Send Retained Messages Only for New Subscriptions: Retained messages are sent only if the subscription did not previously exist. ONE, + /// Do Not Send Retained Messages: Retained messages are not sent when the subscription is established TWO, } @@ -298,7 +310,7 @@ where macro_rules! impl_subscription { ($t:ty) => { - impl From<$t> for Subscription { + impl From<$t> for SubscribeTopics { #[inline] fn from(value: $t) -> Self { Self(vec![IntoSingleSubscription::into(value)]) @@ -307,19 +319,19 @@ macro_rules! impl_subscription { }; } -pub struct Subscription(pub Vec<(Box, SubscriptionOptions)>); +pub struct SubscribeTopics(pub Vec<(Box, SubscriptionOptions)>); // -------------------- Simple types -------------------- impl_subscription!(&str); impl_subscription!(&String); impl_subscription!(String); impl_subscription!(Box); -impl From<&(&str, QoS)> for Subscription { +impl From<&(&str, QoS)> for SubscribeTopics { fn from(value: &(&str, QoS)) -> Self { Self(vec![IntoSingleSubscription::into(value)]) } } -impl From<(T, QoS)> for Subscription +impl From<(T, QoS)> for SubscribeTopics where (T, QoS): IntoSingleSubscription, { @@ -327,7 +339,7 @@ where Self(vec![IntoSingleSubscription::into(value)]) } } -impl From<(T, SubscriptionOptions)> for Subscription +impl From<(T, SubscriptionOptions)> for SubscribeTopics where (T, SubscriptionOptions): IntoSingleSubscription, { @@ -336,7 +348,7 @@ where } } // -------------------- Arrays -------------------- -impl From<&[T; S]> for Subscription +impl From<&[T; S]> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { @@ -345,7 +357,7 @@ where } } // -------------------- Slices -------------------- -impl From<&[T]> for Subscription +impl From<&[T]> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { @@ -354,7 +366,7 @@ where } } // -------------------- Vecs -------------------- -impl From> for Subscription +impl From> for SubscribeTopics where T: IntoSingleSubscription, { @@ -362,7 +374,7 @@ where Self(value.into_iter().map(|val| IntoSingleSubscription::into(val)).collect()) } } -impl From<&Vec> for Subscription +impl From<&Vec> for SubscribeTopics where for<'any> &'any T: IntoSingleSubscription, { diff --git a/mqrstt/src/packets/subscribe/properties.rs b/mqrstt/src/packets/subscribe/properties.rs index 27559cc..6703e99 100644 --- a/mqrstt/src/packets/subscribe/properties.rs +++ b/mqrstt/src/packets/subscribe/properties.rs @@ -4,7 +4,12 @@ use crate::packets::{ PacketType, PropertyType, VariableInteger, }; -crate::packets::macros::define_properties!(SubscribeProperties, SubscriptionIdentifier, UserProperty); +crate::packets::macros::define_properties!( + /// Subscribe Properties + SubscribeProperties, + SubscriptionIdentifier, + UserProperty +); impl MqttRead for SubscribeProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index dd4a6bf..140c967 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -14,6 +14,7 @@ use super::error::SerializeError; use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; use super::PacketAsyncRead; +/// UnsubAck packet is sent by the server in response to an [`crate::packets::Unsubscribe`] packet. #[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct UnsubAck { pub packet_identifier: u16, diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs index 138e85a..85dcab0 100644 --- a/mqrstt/src/packets/unsuback/properties.rs +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -2,7 +2,12 @@ use crate::packets::error::DeserializeError; use crate::packets::mqtt_trait::{MqttRead, MqttWrite, WireLength}; use crate::packets::{PacketType, PropertyType, VariableInteger}; -crate::packets::macros::define_properties!(UnsubAckProperties, ReasonString, UserProperty); +crate::packets::macros::define_properties!( + /// UnsubAck Properties + UnsubAckProperties, + ReasonString, + UserProperty +); impl MqttRead for UnsubAckProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index 23eb72a..b198dad 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -5,13 +5,14 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use crate::packets::mqtt_trait::MqttAsyncRead; -use super::VariableInteger; use super::mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; -use bytes::BufMut; +use super::VariableInteger; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::AsyncReadExt; #[derive(Debug, Clone, PartialEq, Eq)] -/// Used to unsubscribe from topics. -/// +/// Used to unsubscribe from topic(s). +/// /// Multiple topics can be unsubscribed from at once. /// For convenience [`UnsubscribeTopics`] is provided. pub struct Unsubscribe { @@ -53,7 +54,10 @@ impl PacketRead for Unsubscribe { } } -impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncReadExt + Unpin{ +impl PacketAsyncRead for Unsubscribe +where + S: tokio::io::AsyncRead + Unpin, +{ fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { let mut total_read_bytes = 0; @@ -73,11 +77,14 @@ impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncReadExt + Un } } - Ok((Self { - packet_identifier, - properties, - topics, - }, total_read_bytes)) + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } } @@ -207,19 +214,6 @@ where } } -// impl From<&[&T]> for UnsubscribeTopics -// where -// SingleUnsubscribeTopic: for<'any> From<&'any T>, -// { -// fn from(value: &[&T]) -> Self { -// Self( -// value -// .iter() -// .map(|val| SingleUnsubscribeTopic::from(val).0) -// .collect(), -// ) -// } -// } // -------------------- Vecs -------------------- impl From> for UnsubscribeTopics where diff --git a/mqrstt/src/packets/unsubscribe/properties.rs b/mqrstt/src/packets/unsubscribe/properties.rs index e7c628f..510e96d 100644 --- a/mqrstt/src/packets/unsubscribe/properties.rs +++ b/mqrstt/src/packets/unsubscribe/properties.rs @@ -6,7 +6,11 @@ use crate::packets::{ PacketType, PropertyType, }; -crate::packets::macros::define_properties!(UnsubscribeProperties, UserProperty); +crate::packets::macros::define_properties!( + /// Unsubscribe Properties + UnsubscribeProperties, + UserProperty +); impl MqttRead for UnsubscribeProperties { fn read(buf: &mut bytes::Bytes) -> Result { diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 6497dcd..2799b15 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -65,7 +65,7 @@ where S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { + pub(crate) async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; self.last_network_action = Instant::now(); diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 2af3b0f..e004ff8 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -16,7 +16,7 @@ pub fn connack_case() -> (&'static [u8], Packet) { maximum_qos: None, retain_available: Some(true), maximum_packet_size: Some(1048576), - assigned_client_id: None, + assigned_client_identifier: None, topic_alias_maximum: Some(65535), reason_string: None, user_properties: vec![], @@ -202,7 +202,7 @@ pub fn publish_packet_4() -> Packet { } pub fn create_subscribe_packet(packet_identifier: u16) -> Packet { - let subscription: Subscription = "test/topic".into(); + let subscription: SubscribeTopics = "test/topic".into(); let sub = Subscribe::new(packet_identifier, subscription.0); Packet::Subscribe(sub) } From e874911416c6972723fae7e372cac9bee1e00c41 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Sun, 24 Nov 2024 23:30:33 +0100 Subject: [PATCH 13/45] More docs --- mqrstt/benches/bench_main.rs | 2 +- mqrstt/benches/benchmarks/tokio.rs | 4 ++-- mqrstt/src/client.rs | 6 ++++-- mqrstt/src/event_handlers.rs | 18 ++++++++++++++---- mqrstt/src/lib.rs | 3 +++ mqrstt/src/packets/primitive/fixed_header.rs | 20 +++++++++++--------- 6 files changed, 35 insertions(+), 18 deletions(-) diff --git a/mqrstt/benches/bench_main.rs b/mqrstt/benches/bench_main.rs index 9d7f5e4..32d301e 100644 --- a/mqrstt/benches/bench_main.rs +++ b/mqrstt/benches/bench_main.rs @@ -4,5 +4,5 @@ mod benchmarks; criterion_main! { benchmarks::tokio::tokio_concurrent, - benchmarks::tokio::tokio_synchronous, + benchmarks::tokio::tokio_sequential, } diff --git a/mqrstt/benches/benchmarks/tokio.rs b/mqrstt/benches/benchmarks/tokio.rs index cdd1950..74ce0b9 100644 --- a/mqrstt/benches/benchmarks/tokio.rs +++ b/mqrstt/benches/benchmarks/tokio.rs @@ -210,7 +210,7 @@ fn tokio_concurrent_benchmarks(c: &mut Criterion) { }); } -fn tokio_synchronous_benchmarks(c: &mut Criterion) { +fn tokio_sequential_benchmarks(c: &mut Criterion) { let mut group = c.benchmark_group("Tokio sequential"); group.sample_size(30); group.measurement_time(Duration::from_secs(120)); @@ -282,4 +282,4 @@ fn tokio_synchronous_benchmarks(c: &mut Criterion) { } criterion_group!(tokio_concurrent, tokio_concurrent_benchmarks); -criterion_group!(tokio_synchronous, tokio_synchronous_benchmarks); +criterion_group!(tokio_sequential, tokio_sequential_benchmarks); diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index d9f4335..0a96e44 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -22,9 +22,11 @@ use crate::{ }; #[derive(Debug, Clone)] -/// A Clonable client that can be used to send MQTT messages +/// A Clonable client that can be used to send MQTT messages. /// -/// This object is never self constructed but is a obtained by calling the builder functions on [`crate::NetworkBuilder`] +/// This object can be obtained by calling the builder functions on [`crate::NetworkBuilder`] +/// +/// This client should be used in combindation with a handler [`crate::AsyncEventHandler`] or [`crate::AsyncEventHandlerMut`] to handle incoming messages. pub struct MqttClient { /// Provides this client with an available packet id or waits on it. available_packet_ids_r: Receiver, diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index f8ef9c9..0941095 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -4,10 +4,12 @@ use futures::Future; use crate::packets::Packet; -/// Handlers are used to deal with packets before they are further processed (acked) -/// This guarantees that the end user has handlded the packet. -/// Trait for async mutable access to handler. -/// Usefull when you have a single handler +/// Handlers are used to deal with packets before they are acknowledged to the broker. +/// This guarantees that the end user has handlded the packet. Additionally, handlers only deal with incoming packets. +/// +/// This handler can be used to handle message sequentialy or concurrently. +/// +/// To send messages look at [`crate::MqttClient`] pub trait AsyncEventHandler { fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; } @@ -29,16 +31,24 @@ where ::handle(&self, incoming_packet) } } +/// This is a simple no operation handler. impl AsyncEventHandler for () { fn handle(&self, _: Packet) -> impl Future + Send + Sync { async {} } } +/// Handlers are used to deal with packets before they are acknowledged to the broker. +/// This guarantees that the end user has handlded the packet. Additionally, handlers only deal with incoming packets. +/// +/// This handler can be used to handle message sequentialy. +/// +/// To send messages look at [`crate::MqttClient`] pub trait AsyncEventHandlerMut { fn handle(&mut self, incoming_packet: Packet) -> impl Future + Send + Sync; } +/// This is a simple no operation handler. impl AsyncEventHandlerMut for () { fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { async {} diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index 23b1744..65b4462 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -132,6 +132,9 @@ pub mod tokio; /// Wraps all other errors that can be encountered. pub mod error; +/// All event handler traits are defined here. +/// +/// Event handlers are used to process incoming packets. mod event_handlers; /// All MQTT packets are defined here pub mod packets; diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs index 9f28a71..ee5a98b 100644 --- a/mqrstt/src/packets/primitive/fixed_header.rs +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -10,16 +10,18 @@ use crate::packets::{ use super::read_fixed_header_rem_len; /// 2.1.1 Fixed Header -/// ``` -/// 7 3 0 -/// +--------------------------+--------------------------+ -/// byte 1 | MQTT Control Packet Type | Flags for Packet type | -/// +--------------------------+--------------------------+ -/// | Remaining Length | -/// +-----------------------------------------------------+ /// -/// https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021 -/// ``` +/// The fixed header indicates the pakcet type in the first four bits [7 - 4] and for some packets it also contains some flags in the second four bits [3 - 0]. +/// The remaining length encodes the length of the variable header and the payload. +/// +/// | Bit | 7 - 4 | 3 - 0 | +/// |----------|----------------------------|----------------------------| +/// | byte 1 | MQTT Control Packet Type | Flags for Packet type | +/// | | | | +/// | byte 2+ | Remaining Length | +/// | |---------------------------------------------------------| +/// +/// [MQTT v5.0 Specification](https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901021) #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub(crate) struct FixedHeader { pub packet_type: PacketType, From 2adf952ba6c4dd0913874d3322f38db22f2bc36f Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:54:00 +0100 Subject: [PATCH 14/45] added write to all packets --- mqrstt/src/packets/auth/mod.rs | 17 +- mqrstt/src/packets/connack/mod.rs | 31 +++ mqrstt/src/packets/connect/connect_flags.rs | 17 +- mqrstt/src/packets/connect/last_will.rs | 17 +- mqrstt/src/packets/connect/mod.rs | 54 +++++ mqrstt/src/packets/disconnect/mod.rs | 18 ++ mqrstt/src/packets/error.rs | 5 +- .../src/packets/macros/properties_macros.rs | 210 +++++++++++++++++ .../src/packets/macros/reason_code_macros.rs | 221 ++++++++++-------- mqrstt/src/packets/mod.rs | 165 ++++++++++++- mqrstt/src/packets/mqtt_trait/mod.rs | 5 + .../src/packets/mqtt_trait/primitive_impl.rs | 127 +++++++++- mqrstt/src/packets/primitive/property_type.rs | 25 +- .../src/packets/primitive/protocol_version.rs | 15 +- mqrstt/src/packets/primitive/qos.rs | 31 ++- .../src/packets/primitive/variable_integer.rs | 65 +++++- mqrstt/src/packets/puback/mod.rs | 21 ++ mqrstt/src/packets/pubcomp/mod.rs | 40 +++- mqrstt/src/packets/publish/mod.rs | 36 +++ mqrstt/src/packets/pubrec/mod.rs | 20 ++ mqrstt/src/packets/pubrel/mod.rs | 20 ++ mqrstt/src/packets/suback/mod.rs | 26 ++- mqrstt/src/packets/subscribe/mod.rs | 37 ++- mqrstt/src/packets/unsuback/mod.rs | 23 ++ mqrstt/src/packets/unsubscribe/mod.rs | 21 +- mqrstt/src/smol/network.rs | 2 +- mqrstt/src/tests/test_packets.rs | 21 ++ mqrstt/src/tokio/stream/mod.rs | 15 +- 28 files changed, 1156 insertions(+), 149 deletions(-) diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index 412d87d..32d207a 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -1,4 +1,6 @@ mod properties; +use std::future::Future; + pub use properties::AuthProperties; mod reason_code; pub use reason_code::AuthReasonCode; @@ -6,7 +8,7 @@ pub use reason_code::AuthReasonCode; use bytes::Bytes; use super::{ - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketAsyncWrite, PacketRead, PacketWrite, WireLength}, VariableInteger, }; @@ -42,6 +44,19 @@ where } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Auth +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl Future> { + async move { + let reason_code_writen = self.reason_code.async_write(stream).await?; + let properties_writen = self.properties.async_write(stream).await?; + Ok(reason_code_writen + properties_writen) + } + } +} + impl PacketWrite for Auth { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { self.reason_code.write(buf)?; diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 1caf53b..23a49f7 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -80,6 +80,22 @@ impl PacketWrite for ConnAck { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for ConnAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + use crate::packets::mqtt_trait::MqttAsyncWrite; + let connack_flags_writen = self.connack_flags.async_write(stream).await?; + let reason_code_writen = self.reason_code.async_write(stream).await?; + let connack_properties_writen = self.connack_properties.async_write(stream).await?; + + Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) + } + } +} + impl WireLength for ConnAck { fn wire_len(&self) -> usize { 2 + // 1 for connack_flags and 1 for reason_code @@ -133,6 +149,21 @@ impl MqttWrite for ConnAckFlags { } } +impl crate::packets::mqtt_trait::MqttAsyncWrite for ConnAckFlags +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + use tokio::io::AsyncWriteExt; + let byte = self.session_present as u8; + + stream.write_u8(byte).await?; + Ok(1) + } + } +} + #[cfg(test)] mod tests { diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs index 2066b51..c2a576e 100644 --- a/mqrstt/src/packets/connect/connect_flags.rs +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -4,7 +4,7 @@ use tokio::io::AsyncReadExt; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, QoS, }; @@ -91,3 +91,18 @@ impl MqttWrite for ConnectFlags { Ok(()) } } + +impl MqttAsyncWrite for ConnectFlags +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + use tokio::io::AsyncWriteExt; + let byte = self.into_u8()?; + stream.write_u8(byte).await?; + + Ok(1) + } + } +} diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index 5d89ec0..b93e88e 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -2,7 +2,7 @@ use bytes::{Bytes, BytesMut}; use crate::packets::{ error::{DeserializeError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, QoS, WireLength, }; @@ -78,6 +78,21 @@ impl MqttWrite for LastWill { } } +impl MqttAsyncWrite for LastWill +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let properties_written = self.last_will_properties.async_write(stream).await?; + let topic_written = self.topic.async_write(stream).await?; + let payload_written = self.payload.async_write(stream).await?; + + Ok(properties_written + topic_written + payload_written) + } + } +} + impl WireLength for LastWill { fn wire_len(&self) -> usize { let property_len = self.last_will_properties.wire_len(); diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs index 1674db2..463f6e4 100644 --- a/mqrstt/src/packets/connect/mod.rs +++ b/mqrstt/src/packets/connect/mod.rs @@ -215,6 +215,60 @@ impl PacketWrite for Connect { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Connect +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_writen_bytes = 6 // protocol header + + 1 // protocol version + + 1 // connect flags + + 2; // keep alive + let protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; + // We allready start with 6 as total writen bytes thus dont add anymore + stream.write_all(&protocol).await?; + + self.protocol_version.async_write(stream).await?; + + let mut connect_flags = ConnectFlags { + clean_start: self.clean_start, + username: self.username.is_some(), + password: self.password.is_some(), + ..Default::default() + }; + + if let Some(last_will) = &self.last_will { + connect_flags.will_flag = true; + connect_flags.will_retain = last_will.retain; + connect_flags.will_qos = last_will.qos; + } + + connect_flags.async_write(stream).await?; + + stream.write_u16(self.keep_alive).await?; + + total_writen_bytes += self.connect_properties.async_write(stream).await?; + + total_writen_bytes += self.client_id.async_write(stream).await?; + + if let Some(last_will) = &self.last_will { + total_writen_bytes += last_will.async_write(stream).await?; + } + if let Some(username) = &self.username { + total_writen_bytes += username.async_write(stream).await?; + } + if let Some(password) = &self.password { + total_writen_bytes += password.async_write(stream).await?; + } + + Ok(total_writen_bytes) + } + } +} + impl WireLength for Connect { fn wire_len(&self) -> usize { let mut len = "MQTT".wire_len() + 1 + 1 + 2; // protocol version, connect_flags and keep alive diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs index e5d34ee..d8c1211 100644 --- a/mqrstt/src/packets/disconnect/mod.rs +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -39,6 +39,24 @@ where } } } + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Disconnect +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 0; + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + impl PacketRead for Disconnect { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let reason_code; diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index 53abcc8..68ede44 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -46,6 +46,9 @@ pub enum DeserializeError { #[error("Read more data for the packet than indicated length")] ReadTooMuchData(&'static str, usize, usize), + #[error("While reading a packet {read} bytes was read, but the packet indicated a remaining length of {remaining_length} bytes")] + RemainingDataError { read: usize, remaining_length: usize }, + #[error("Reason code {0} is not allowed for packet type {1:?}")] UnexpectedReasonCode(u8, PacketType), @@ -72,7 +75,7 @@ impl From for DeserializeError { } #[derive(Error, Clone, Debug)] -pub(crate) enum ReadBytes { +pub enum ReadBytes { #[error("Normal error")] Err(#[from] T), diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index 89c20ed..3befb6d 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -33,6 +33,43 @@ macro_rules! define_properties { } } + impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let mut bytes_writen = 0; + $( + $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); + )* + + Ok(bytes_writen) + } + + // let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; + // if len == 0 { + // return Ok((Self::default(), length_variable_integer)); + // } + + // let mut properties = $name::default(); + + // let mut read_property_bytes = 0; + // loop { + // let (prop, read_bytes) = crate::packets::PropertyType::async_read(stream).await?; + // read_property_bytes += read_bytes; + // match prop { + // $( + // $crate::packets::macros::properties_read_match_branch_name!($prop_variant) => $crate::packets::macros::properties_read_match_branch_body!(stream, properties, read_property_bytes, PropertyType::$prop_variant), + // )* + // e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), + // } + // if read_property_bytes == len { + // break; + // } + // } + + // Ok((properties, length_variable_integer + read_property_bytes)) + } + } + impl $crate::packets::mqtt_trait::WireLength for $name { fn wire_len(&self) -> usize { let mut len: usize = 0; @@ -543,6 +580,178 @@ macro_rules! properties_read_match_branch_name { }; } +macro_rules! properties_write { + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { + if let Some(payload_format_indicator) = &($self.payload_format_indicator) { + $bytes_writen += PropertyType::PayloadFormatIndicator.async_write($stream).await?; + $bytes_writen += payload_format_indicator.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { + if let Some(message_expiry_interval) = &($self.message_expiry_interval) { + $bytes_writen += PropertyType::MessageExpiryInterval.async_write($stream).await?; + $bytes_writen += message_expiry_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ContentType) => { + if let Some(content_type) = &($self.content_type) { + $bytes_writen += PropertyType::ContentType.async_write($stream).await?; + $bytes_writen += content_type.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseTopic) => { + if let Some(response_topic) = &($self.response_topic) { + $bytes_writen += PropertyType::ResponseTopic.async_write($stream).await?; + $bytes_writen += response_topic.as_ref().async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::CorrelationData) => { + if let Some(correlation_data) = &($self.correlation_data) { + $bytes_writen += PropertyType::CorrelationData.async_write($stream).await?; + $bytes_writen += correlation_data.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { + if let Some(sub_id) = &($self.subscription_identifier) { + $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { + for sub_id in &($self.subscription_identifiers) { + $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { + if let Some(session_expiry_interval) = &($self.session_expiry_interval) { + $bytes_writen += PropertyType::SessionExpiryInterval.async_write($stream).await?; + $bytes_writen += session_expiry_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerKeepAlive) => { + if let Some(server_keep_alive) = &($self.server_keep_alive) { + $bytes_writen += PropertyType::ServerKeepAlive.async_write($stream).await?; + $bytes_writen += server_keep_alive.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationMethod) => { + if let Some(authentication_method) = &($self.authentication_method) { + $bytes_writen += PropertyType::AuthenticationMethod.async_write($stream).await?; + $bytes_writen += authentication_method.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationData) => { + if let Some(authentication_data) = &($self.authentication_data) { + if !authentication_data.is_empty() && ($self.authentication_method).is_some() { + $bytes_writen += PropertyType::AuthenticationData.async_write($stream).await?; + $bytes_writen += authentication_data.async_write($stream).await?; + } + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestProblemInformation) => { + if let Some(request_problem_information) = &($self.request_problem_information) { + $bytes_writen += PropertyType::RequestProblemInformation.async_write($stream).await?; + $bytes_writen += request_problem_information.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WillDelayInterval) => { + if let Some(delay_interval) = &($self.will_delay_interval) { + $bytes_writen += PropertyType::WillDelayInterval.async_write($stream).await?; + $bytes_writen += delay_interval.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestResponseInformation) => { + if let Some(request_response_information) = &($self.request_response_information) { + $bytes_writen += PropertyType::RequestResponseInformation.async_write($stream).await?; + $bytes_writen += request_response_information.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseInformation) => { + if let Some(response_info) = &($self.response_info) { + $bytes_writen += PropertyType::ResponseInformation.async_write($stream).await?; + $bytes_writen += response_info.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerReference) => { + if let Some(server_refrence) = &($self.server_reference) { + $bytes_writen += PropertyType::ServerReference.async_write($stream).await?; + server_refrence.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReasonString) => { + if let Some(reason_string) = &($self.reason_string) { + $bytes_writen += PropertyType::ReasonString.async_write($stream).await?; + $bytes_writen += reason_string.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReceiveMaximum) => { + if let Some(receive_maximum) = &($self.receive_maximum) { + $bytes_writen += PropertyType::ReceiveMaximum.async_write($stream).await?; + $bytes_writen += receive_maximum.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { + if let Some(topic_alias_maximum) = &($self.topic_alias_maximum) { + $bytes_writen += PropertyType::TopicAliasMaximum.async_write($stream).await?; + $bytes_writen += topic_alias_maximum.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAlias) => { + if let Some(topic_alias) = &($self.topic_alias) { + $bytes_writen += PropertyType::TopicAlias.async_write($stream).await?; + $bytes_writen += topic_alias.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumQos) => { + if let Some(maximum_qos) = &($self.maximum_qos) { + $bytes_writen += PropertyType::MaximumQos.async_write($stream).await?; + $bytes_writen += maximum_qos.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RetainAvailable) => { + if let Some(retain_available) = &($self.retain_available) { + $bytes_writen += PropertyType::RetainAvailable.async_write($stream).await?; + $bytes_writen += retain_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::UserProperty) => { + for (key, value) in &($self.user_properties) { + $bytes_writen += PropertyType::UserProperty.async_write($stream).await?; + $bytes_writen += key.async_write($stream).await?; + $bytes_writen += value.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumPacketSize) => { + if let Some(maximum_packet_size) = &($self.maximum_packet_size) { + $bytes_writen += PropertyType::MaximumPacketSize.async_write($stream).await?; + $bytes_writen += maximum_packet_size.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { + if let Some(wildcards_available) = &($self.wildcards_available) { + $bytes_writen += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; + $bytes_writen += wildcards_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { + if let Some(subscription_ids_available) = &($self.subscription_ids_available) { + $bytes_writen += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; + $bytes_writen += subscription_ids_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { + if let Some(shared_subscription_available) = &($self.shared_subscription_available) { + $bytes_writen += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; + $bytes_writen += shared_subscription_available.async_write($stream).await?; + } + }; + ($self:ident, $bytes_writen:ident, $stream:ident, $unknown:ident) => { + compile_error!(concat!("Unknown property: ", stringify!($unknown))); + }; +} + macro_rules! properties_wire_length { ($self:ident, $len:ident, PropertyType::PayloadFormatIndicator) => { if $self.payload_format_indicator.is_some() { @@ -698,3 +907,4 @@ pub(crate) use properties_read_match_branch_body; pub(crate) use properties_read_match_branch_name; pub(crate) use properties_struct; pub(crate) use properties_wire_length; +pub(crate) use properties_write; diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs index b56a669..8e1fa29 100644 --- a/mqrstt/src/packets/macros/reason_code_macros.rs +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -7,12 +7,24 @@ macro_rules! reason_code { $($code),* } + impl $name { + pub(crate) fn from_u8(val: u8) -> Result { + $crate::packets::macros::reason_code_match!(@ $name, val, { + $($code,)* + } -> ()) + } + + pub(crate) fn to_u8(&self) -> u8 { + $crate::packets::macros::reason_code_match_write!(@ $name, self, { + $($code,)* + } -> ()) + } + } + impl $crate::packets::mqtt_trait::MqttAsyncRead for $name where S: tokio::io::AsyncRead + std::marker::Unpin{ async fn async_read(stream: &mut S) -> Result<(Self, usize), $crate::packets::error::ReadError> { let input = stream.read_u8().await?; - let res = $crate::packets::macros::reason_code_match!(@ $name, input, { - $($code,)* - } -> ())?; + let res = Self::from_u8(input)?; Ok((res, 1)) } } @@ -23,19 +35,24 @@ macro_rules! reason_code { return Err($crate::packets::error::DeserializeError::InsufficientData(std::any::type_name::(), 0, 1)); } use bytes::Buf; - let res = buf.get_u8(); - $crate::packets::macros::reason_code_match!(@ $name, res, { - $($code,)* - } -> ()) + let input = buf.get_u8(); + Self::from_u8(input) + } + } + + impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + std::marker::Unpin{ + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let val = self.to_u8(); + stream.write_u8(val).await?; + Ok(1) } } impl $crate::packets::mqtt_trait::MqttWrite for $name { fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), $crate::packets::error::SerializeError> { - let val = $crate::packets::macros::reason_code_match_write!(@ $name, buf, self, { - $($code,)* - } -> ()); use bytes::BufMut; + let val = self.to_u8(); buf.put_u8(val); Ok(()) } @@ -327,285 +344,285 @@ macro_rules! reason_code_match { } macro_rules! reason_code_match_write{ - ( @ $name:ident, $buf:ident, $input:ident, { } -> ($($result:tt)*) ) => ( + ( @ $name:ident, $input:ident, { } -> ($($result:tt)*) ) => ( match $input { $($result)* } ); - ( @ $name:ident, $buf:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { Success, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::Success => 0x00, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { NormalDisconnection, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::NormalDisconnection => 0x00, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { GrantedQoS0, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::GrantedQoS0 => 0x00, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { GrantedQoS1, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::GrantedQoS1 => 0x01, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { GrantedQoS2, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::GrantedQoS2 => 0x02, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { DisconnectWithWillMessage, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::DisconnectWithWillMessage => 0x04, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { NoMatchingSubscribers, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::NoMatchingSubscribers => 0x10, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { NoSubscriptionExisted, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::NoSubscriptionExisted => 0x11, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ContinueAuthentication, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ContinueAuthentication => 0x18, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ReAuthenticate, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ReAuthenticate => 0x19, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { UnspecifiedError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::UnspecifiedError => 0x80, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { MalformedPacket, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::MalformedPacket => 0x81, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ProtocolError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ProtocolError => 0x82, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ImplementationSpecificError, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ImplementationSpecificError => 0x83, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { UnsupportedProtocolVersion, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::UnsupportedProtocolVersion => 0x84, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ClientIdentifierNotValid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ClientIdentifierNotValid => 0x85, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { BadUsernameOrPassword, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::BadUsernameOrPassword => 0x86, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { NotAuthorized, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::NotAuthorized => 0x87, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ServerUnavailable, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ServerUnavailable => 0x88, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ServerBusy, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ServerBusy => 0x89, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { Banned, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::Banned => 0x8A, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ServerShuttingDown, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ServerShuttingDown => 0x8B , )) ); - ( @ $name:ident, $buf:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { BadAuthenticationMethod, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::BadAuthenticationMethod => 0x8C, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { KeepAliveTimeout, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::KeepAliveTimeout => 0x8D, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { SessionTakenOver, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::SessionTakenOver => 0x8E, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { TopicFilterInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::TopicFilterInvalid => 0x8F, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { TopicNameInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::TopicNameInvalid => 0x90, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { PacketIdentifierInUse, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::PacketIdentifierInUse => 0x91, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { PacketIdentifierNotFound, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::PacketIdentifierNotFound => 0x92, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ReceiveMaximumExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ReceiveMaximumExceeded => 0x93, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { TopicAliasInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::TopicAliasInvalid => 0x94, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { PacketTooLarge, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::PacketTooLarge => 0x95, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { MessageRateTooHigh, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::MessageRateTooHigh => 0x96, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { QuotaExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::QuotaExceeded => 0x97, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { AdministrativeAction, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::AdministrativeAction => 0x98, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { PayloadFormatInvalid, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::PayloadFormatInvalid => 0x99, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { RetainNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::RetainNotSupported => 0x9A, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { QosNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::QosNotSupported => 0x9B, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { UseAnotherServer, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::UseAnotherServer => 0x9C, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ServerMoved, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ServerMoved => 0x9D, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { SharedSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::SharedSubscriptionsNotSupported => 0x9E, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { ConnectionRateExceeded, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::ConnectionRateExceeded => 0x9F, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { MaximumConnectTime, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::MaximumConnectTime => 0xA0, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { SubscriptionIdentifiersNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::SubscriptionIdentifiersNotSupported => 0xA1, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( - $crate::packets::macros::reason_code_match_write!(@ $name, $buf, $input, { $($rest)* } -> ( + ( @ $name:ident, $input:ident, { WildcardSubscriptionsNotSupported, $($rest:tt)* } -> ($($result:tt)*) ) => ( + $crate::packets::macros::reason_code_match_write!(@ $name, $input, { $($rest)* } -> ( $($result)* $name::WildcardSubscriptionsNotSupported => 0xA2, )) ); - ( @ $name:ident, $buf:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( + ( @ $name:ident, $input:ident, { $unknown:ident, $($rest:tt)* } -> ($($result:tt)*) ) => ( compile_error!(concat!("Unknown reason_code: ", stringify!($unknown))) ); } diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index c77969f..208ae19 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -18,8 +18,8 @@ mod unsuback; mod unsubscribe; mod primitive; -use error::ReadError; -use mqtt_trait::PacketAsyncRead; +use error::{ReadError, WriteError}; +use mqtt_trait::{PacketAsyncRead, PacketAsyncWrite}; pub use primitive::*; pub use auth::*; @@ -83,6 +83,38 @@ impl Packet { } } + pub(crate) fn first_byte(&self) -> u8 { + match self { + Packet::Connect(_) => 0b0001_0000, + Packet::ConnAck(_) => 0b0010_0000, + Packet::Publish(p) => { + let mut first_byte = 0b0011_0000u8; + if p.dup { + first_byte |= 0b1000; + } + first_byte |= p.qos.into_u8() << 1; + if p.retain { + first_byte |= 0b0001; + } + first_byte + } + Packet::PubAck(_) => 0b0100_0000, + Packet::PubRec(_) => 0b0101_0000, + Packet::PubRel(_) => 0b0110_0010, + Packet::PubComp(_) => 0b0111_0000, + Packet::Subscribe(_) => 0b1000_0010, + Packet::SubAck(_) => { + unreachable!() + } + Packet::Unsubscribe(_) => 0b1010_0010, + Packet::UnsubAck(_) => 0b1011_0000, + Packet::PingReq => 0b1100_0000, + Packet::PingResp => 0b1101_0000, + Packet::Disconnect(_) => 0b1110_0000, + Packet::Auth(_) => 0b1111_0000, + } + } + pub(crate) fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { match self { Packet::Connect(p) => { @@ -170,6 +202,97 @@ impl Packet { Ok(()) } + pub(crate) async fn async_write(&self, stream: &mut S) -> Result + where + S: tokio::io::AsyncWrite + Unpin, + { + use tokio::io::AsyncWriteExt; + let mut written = 1; + match self { + Packet::Connect(p) => { + stream.write_u8(0b0001_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::ConnAck(p) => { + stream.write_u8(0b0010_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Publish(p) => { + let mut first_byte = 0b0011_0000u8; + if p.dup { + first_byte |= 0b1000; + } + + first_byte |= p.qos.into_u8() << 1; + + if p.retain { + first_byte |= 0b0001; + } + stream.write_u8(first_byte).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubAck(p) => { + stream.write_u8(0b0100_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubRec(p) => { + stream.write_u8(0b0101_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubRel(p) => { + stream.write_u8(0b0110_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::PubComp(p) => { + stream.write_u8(0b0111_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Subscribe(p) => { + stream.write_u8(0b1000_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::SubAck(_) => { + unreachable!() + } + Packet::Unsubscribe(p) => { + stream.write_u8(0b1010_0010).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::UnsubAck(_) => { + unreachable!(); + // stream.write_u8(0b1011_0000).await?; + } + Packet::PingReq => { + stream.write_u8(0b1100_0000).await?; + stream.write_u8(0).await?; // Variable header length. + } + Packet::PingResp => { + stream.write_u8(0b1101_0000).await?; + stream.write_u8(0).await?; // Variable header length. + } + Packet::Disconnect(p) => { + stream.write_u8(0b1110_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + Packet::Auth(p) => { + stream.write_u8(0b1111_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; + } + } + Ok(written) + } + pub(crate) fn read(header: FixedHeader, buf: Bytes) -> Result { let packet = match header.packet_type { PacketType::Connect => Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?), @@ -215,8 +338,7 @@ impl Packet { Ok(packet) } - #[cfg(test)] - pub(crate) async fn async_read_from_buffer(stream: &mut S) -> Result + pub async fn async_read_from_buffer(stream: &mut S) -> Result where S: tokio::io::AsyncRead + Unpin, { @@ -225,8 +347,7 @@ impl Packet { Ok(Packet::async_read(header, stream).await?) } - #[cfg(test)] - pub(crate) fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { use bytes::Buf; use error::ReadBytes; @@ -327,7 +448,26 @@ mod tests { use crate::packets::Packet; - use crate::tests::test_packets::{disconnect_case, ping_req_case, ping_resp_case, publish_case, pubrel_case, pubrel_smallest_case}; + use crate::tests::test_packets::{create_empty_publish_packet, disconnect_case, ping_req_case, ping_resp_case, publish_case, pubrel_case, pubrel_smallest_case}; + + #[rstest::rstest] + // #[case(disconnect_case())] + // #[case(ping_req_case())] + // #[case(ping_resp_case())] + // #[case(publish_case())] + // #[case(pubrel_case())] + #[case(create_empty_publish_packet())] + fn test_write_read_write_read_cases(#[case] packet: Packet) { + let mut buffer = BytesMut::new(); + packet.write(&mut buffer).unwrap(); + let res1 = Packet::read_from_buffer(&mut buffer).unwrap(); + + let mut buffer = BytesMut::new(); + res1.write(&mut buffer).unwrap(); + let res2 = Packet::read_from_buffer(&mut buffer).unwrap(); + + assert_eq!(res1, res2); + } #[rstest::rstest] #[case(disconnect_case())] @@ -378,4 +518,15 @@ mod tests { // assert_eq!() } + + // #[rstest::rstest] + // #[case(&[59, 1, 0, 59])] + // #[case(&[16, 14, 0, 4, 77, 81, 84, 84, 5, 247, 247, 252, 1, 17, 247, 247, 247])] + // fn test_read_error(#[case] bytes: &[u8]) { + // let mut buffer = BytesMut::from_iter(bytes); + + // let res = Packet::read_from_buffer(&mut buffer); + + // assert!(res.is_err()); + // } } diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index daada69..68d32ce 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -51,6 +51,11 @@ where ::write(self, buf) } } +pub(crate) trait MqttAsyncWrite: Sized { + /// Write `Self` to the provided stream. + /// Returns the deserialized instance and the number of bytes read from the stream. + fn async_write(&self, stream: &mut S) -> impl Future>; +} pub trait PacketValidation: Sized { fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs index 503c4f5..8ff217a 100644 --- a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -1,8 +1,12 @@ +use tokio::io::AsyncWriteExt; + use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::packets::error::{DeserializeError, ReadError, SerializeError}; use crate::packets::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, WireLength}; +use super::MqttAsyncWrite; + impl MqttRead for Box { #[inline] fn read(buf: &mut Bytes) -> Result { @@ -35,6 +39,20 @@ impl MqttWrite for Box { } } +impl MqttAsyncWrite for Box +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } + } +} + impl WireLength for Box { #[inline(always)] fn wire_len(&self) -> usize { @@ -51,6 +69,20 @@ impl MqttWrite for &str { } } +impl MqttAsyncWrite for &str +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } + } +} + impl WireLength for &str { #[inline(always)] fn wire_len(&self) -> usize { @@ -95,6 +127,19 @@ impl MqttWrite for String { Ok(()) } } +impl MqttAsyncWrite for String +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) + } + } +} impl WireLength for String { #[inline(always)] @@ -106,6 +151,9 @@ impl WireLength for String { impl MqttRead for Bytes { #[inline] fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } let len = buf.get_u16() as usize; if len > buf.len() { @@ -137,6 +185,20 @@ impl MqttWrite for Bytes { Ok(()) } } +impl MqttAsyncWrite for Bytes +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_ref()).await?; + Ok(2 + self.len()) + } + } +} + impl WireLength for Bytes { #[inline(always)] fn wire_len(&self) -> usize { @@ -147,6 +209,9 @@ impl WireLength for Bytes { impl MqttRead for Vec { #[inline] fn read(buf: &mut Bytes) -> Result { + if buf.len() < 2 { + return Err(DeserializeError::InsufficientData(std::any::type_name::(), buf.len(), 2)); + } let len = buf.get_u16() as usize; if len > buf.len() { @@ -165,6 +230,19 @@ impl MqttWrite for Vec { Ok(()) } } +impl MqttAsyncWrite for Vec +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self).await?; + Ok(2 + self.len()) + } + } +} impl WireLength for Vec { #[inline(always)] fn wire_len(&self) -> usize { @@ -222,7 +300,21 @@ impl MqttWrite for bool { } } } - +impl MqttAsyncWrite for bool +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + if *self { + stream.write_all(&[1]).await?; + } else { + stream.write_all(&[0]).await?; + } + Ok(1) + } + } +} impl MqttRead for u8 { #[inline] fn read(buf: &mut Bytes) -> Result { @@ -240,6 +332,17 @@ where Ok((buf.read_u8().await?, 1)) } } +impl MqttAsyncWrite for u8 +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(1) + } + } +} impl MqttRead for u16 { #[inline] @@ -265,6 +368,17 @@ impl MqttWrite for u16 { Ok(()) } } +impl MqttAsyncWrite for u16 +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(2) + } + } +} impl MqttRead for u32 { #[inline] @@ -289,3 +403,14 @@ impl MqttWrite for u32 { Ok(()) } } +impl MqttAsyncWrite for u32 +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(4) + } + } +} diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs index 54bd00f..cce56e1 100644 --- a/mqrstt/src/packets/primitive/property_type.rs +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -1,8 +1,10 @@ +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use bytes::{Buf, BufMut, Bytes, BytesMut}; use crate::packets::{ error::{DeserializeError, ReadError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -123,12 +125,12 @@ impl MqttRead for PropertyType { } } -impl MqttAsyncRead for PropertyType +impl MqttAsyncRead for PropertyType where - T: tokio::io::AsyncReadExt + std::marker::Unpin, + S: tokio::io::AsyncRead + std::marker::Unpin, { - async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { - match buf.read_u8().await { + async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> { + match stream.read_u8().await { Ok(t) => Ok((t.try_into()?, 1)), Err(e) => Err(ReadError::IoError(e)), } @@ -141,3 +143,16 @@ impl MqttWrite for PropertyType { Ok(()) } } + +impl MqttAsyncWrite for PropertyType +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let buf: [u8; 1] = [u8::from(self)]; + stream.write_all(&buf).await?; + Ok(1) + } + } +} diff --git a/mqrstt/src/packets/primitive/protocol_version.rs b/mqrstt/src/packets/primitive/protocol_version.rs index 84ca01a..bed9a18 100644 --- a/mqrstt/src/packets/primitive/protocol_version.rs +++ b/mqrstt/src/packets/primitive/protocol_version.rs @@ -4,7 +4,7 @@ use tokio::io::AsyncReadExt; use crate::packets::{ error::{DeserializeError, ReadError, SerializeError}, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, }; /// Protocol version of the MQTT connection @@ -22,6 +22,19 @@ impl MqttWrite for ProtocolVersion { } } +impl MqttAsyncWrite for ProtocolVersion +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use tokio::io::AsyncWriteExt; + async move { + stream.write_u8(5).await?; + Ok(1) + } + } +} + impl MqttRead for ProtocolVersion { fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { diff --git a/mqrstt/src/packets/primitive/qos.rs b/mqrstt/src/packets/primitive/qos.rs index f019587..572a6a5 100644 --- a/mqrstt/src/packets/primitive/qos.rs +++ b/mqrstt/src/packets/primitive/qos.rs @@ -1,7 +1,11 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; -use crate::packets::{error::{DeserializeError, ReadError, SerializeError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite}}; +use crate::packets::{ + error::{DeserializeError, ReadError, SerializeError}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite}, +}; +use tokio::io::AsyncWriteExt; /// Quality of service #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -45,7 +49,10 @@ impl MqttRead for QoS { } } -impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker::Unpin { +impl MqttAsyncRead for QoS +where + T: tokio::io::AsyncReadExt + std::marker::Unpin, +{ async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> { match buf.read_u8().await { Ok(0) => Ok((QoS::AtMostOnce, 1)), @@ -60,12 +67,20 @@ impl MqttAsyncRead for QoS where T: tokio::io::AsyncReadExt + std::marker: impl MqttWrite for QoS { #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { - let val = match self { - QoS::AtMostOnce => 0, - QoS::AtLeastOnce => 1, - QoS::ExactlyOnce => 2, - }; + let val = self.into_u8(); buf.put_u8(val); Ok(()) } -} \ No newline at end of file +} +impl MqttAsyncWrite for QoS +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async move { + let buf: [u8; 1] = [self.into_u8()]; + stream.write_all(&buf).await?; + Ok(1) + } + } +} diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs index 673fe36..276d613 100644 --- a/mqrstt/src/packets/primitive/variable_integer.rs +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -1,9 +1,11 @@ +use crate::packets::error::WriteError; use crate::packets::error::{DeserializeError, ReadBytes, ReadError, SerializeError}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use core::slice::Iter; use std::future::Future; use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; pub(crate) fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { let mut integer = 0; @@ -45,7 +47,8 @@ pub(crate) trait VariableInteger: Sized { fn variable_integer_len(&self) -> usize; fn write_variable_integer(&self, buf: &mut BytesMut) -> Result; fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError>; - fn read_async_variable_integer(stream: &mut S) -> impl Future>; + fn read_async_variable_integer(stream: &mut S) -> impl Future>; + fn write_async_variable_integer(&self, stream: &mut S) -> impl Future>; } impl VariableInteger for usize { @@ -102,7 +105,7 @@ impl VariableInteger for usize { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { + fn read_async_variable_integer(stream: &mut S) -> impl Future> { async move { let mut integer = 0; let mut length = 0; @@ -120,6 +123,34 @@ impl VariableInteger for usize { Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } } + + fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { + async move { + let mut buf = [0u8; 4]; + + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); + } + + let mut write = *self; + let mut length = 1; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf[i] = byte; + if write == 0 { + length = i + 1; + break; + } + } + stream.write_all(&buf[0..length]).await; + Ok(length) + } + } } impl VariableInteger for u32 { @@ -176,7 +207,7 @@ impl VariableInteger for u32 { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { + fn read_async_variable_integer(stream: &mut S) -> impl Future> { async move { let mut integer = 0; let mut length = 0; @@ -194,4 +225,32 @@ impl VariableInteger for u32 { Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } } + + fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { + async move { + let mut buf = [0u8; 4]; + + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); + } + + let mut write = *self; + let mut length = 1; + + for i in 0..4 { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + buf[i] = byte; + if write == 0 { + length = i + 1; + break; + } + } + stream.write_all(&buf[0..length]).await; + Ok(length) + } + } } diff --git a/mqrstt/src/packets/puback/mod.rs b/mqrstt/src/packets/puback/mod.rs index e91a56a..7ab0ba0 100644 --- a/mqrstt/src/packets/puback/mod.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -98,6 +98,27 @@ impl PacketWrite for PubAck { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 0; + if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_written_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_written_bytes += self.reason_code.async_write(stream).await?; + } else { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + impl WireLength for PubAck { fn wire_len(&self) -> usize { if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 24355aa..229036e 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -5,7 +5,7 @@ mod properties; pub use properties::PubCompProperties; use super::{ - error::DeserializeError, + error::{DeserializeError, ReadError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, }; use bytes::BufMut; @@ -77,15 +77,22 @@ where 2, )); } - // Requires u16, u8 and at leasy 1 byte of variable integer prop length so at least 4 bytes + // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes else if remaining_length < 4 { - return Err(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4).into()); + return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); } let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; - assert_eq!(2 + reason_code_read_bytes + properties_read_bytes, remaining_length); + let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; + + if total_read_bytes != remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { + read: total_read_bytes, + remaining_length: remaining_length, + })); + } Ok(( Self { @@ -93,7 +100,7 @@ where reason_code, properties, }, - 2 + reason_code_read_bytes + properties_read_bytes, + total_read_bytes, )) } } @@ -104,7 +111,7 @@ impl PacketWrite for PubComp { buf.put_u16(self.packet_identifier); if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - // nothing here + return Ok(()); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { self.reason_code.write(buf)?; } else { @@ -115,6 +122,27 @@ impl PacketWrite for PubComp { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubComp +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_writen_bytes = 0; + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_writen_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_writen_bytes += self.reason_code.async_write(stream).await?; + } else { + total_writen_bytes += self.reason_code.async_write(stream).await?; + total_writen_bytes += self.properties.async_write(stream).await?; + } + Ok(total_writen_bytes) + } + } +} + impl WireLength for PubComp { fn wire_len(&self) -> usize { if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index f95c852..bfb17a3 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -6,6 +6,7 @@ use tokio::io::AsyncReadExt; use bytes::BufMut; use crate::error::PacketValidationError; +use crate::packets::error::ReadError; use crate::util::constants::MAXIMUM_TOPIC_SIZE; use super::mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; @@ -113,6 +114,16 @@ where let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; total_read_bytes += properties_read_bytes; + if total_read_bytes != remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { + read: total_read_bytes, + remaining_length: remaining_length, + })); + } + + if total_read_bytes > remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); + } let payload_len = remaining_length - total_read_bytes; let mut payload = vec![0u8; payload_len]; let payload_read_bytes = stream.read_exact(&mut payload).await?; @@ -151,6 +162,31 @@ impl PacketWrite for Publish { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Publish +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 0; + total_written_bytes += self.topic.async_write(stream).await?; + + if let Some(pkid) = self.packet_identifier { + stream.write_u16(pkid).await?; + total_written_bytes += 2; + } + total_written_bytes += self.publish_properties.async_write(stream).await?; + + stream.write_all(&self.payload).await?; + total_written_bytes += self.payload.len(); + + Ok(total_written_bytes) + } + } +} + impl WireLength for Publish { fn wire_len(&self) -> usize { let mut len = self.topic.wire_len(); diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 639fcf6..aa36eaf 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -113,6 +113,26 @@ impl PacketWrite for PubRec { Ok(()) } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubRec +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_writen_bytes = 0; + if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_writen_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_writen_bytes += self.reason_code.async_write(stream).await?; + } else { + total_writen_bytes += self.reason_code.async_write(stream).await?; + total_writen_bytes += self.properties.async_write(stream).await?; + } + Ok(total_writen_bytes) + } + } +} impl WireLength for PubRec { fn wire_len(&self) -> usize { diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index c7beff6..6e42a76 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -108,6 +108,26 @@ impl PacketWrite for PubRel { Ok(()) } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for PubRel +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_writen_bytes = 0; + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_writen_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_writen_bytes += self.reason_code.async_write(stream).await?; + } else { + total_writen_bytes += self.reason_code.async_write(stream).await?; + total_writen_bytes += self.properties.async_write(stream).await?; + } + Ok(total_writen_bytes) + } + } +} impl WireLength for PubRel { fn wire_len(&self) -> usize { diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index b34c728..4bbd138 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -27,13 +27,10 @@ impl PacketRead for SubAck { let packet_identifier = u16::read(&mut buf)?; let properties = SubAckProperties::read(&mut buf)?; - dbg!("aa"); - let mut reason_codes = vec![]; loop { let reason_code = SubAckReasonCode::read(&mut buf)?; - dbg!(reason_code); reason_codes.push(reason_code); if buf.is_empty() { @@ -95,6 +92,29 @@ impl PacketWrite for SubAck { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for SubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; + } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) + } + } +} + #[cfg(test)] mod test { use bytes::BytesMut; diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index 1b9deac..e97cd36 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -7,7 +7,7 @@ use crate::{error::PacketValidationError, util::constants::MAXIMUM_TOPIC_SIZE}; use super::{ error::DeserializeError, - mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}, QoS, VariableInteger, }; use bytes::{Buf, BufMut}; @@ -106,6 +106,27 @@ impl PacketWrite for Subscribe { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Subscribe +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + for (topic, options) in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + total_written_bytes += options.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + impl WireLength for Subscribe { fn wire_len(&self) -> usize { let mut len = 2; @@ -209,7 +230,19 @@ impl MqttWrite for SubscriptionOptions { } } -// Please describe the retain handling type, what is it used for +impl MqttAsyncWrite for SubscriptionOptions +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use tokio::io::AsyncWriteExt; + async move { + let byte = (self.retain_handling.into_u8() << 4) | ((self.retain_as_publish as u8) << 3) | ((self.no_local as u8) << 2) | self.qos.into_u8(); + stream.write_u8(byte).await?; + Ok(1) + } + } +} /// Controls how retained messages are handled /// diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index 140c967..da06d52 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -93,6 +93,29 @@ impl PacketWrite for UnsubAck { } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for UnsubAck +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + total_written_bytes += self.properties.async_write(stream).await?; + + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; + } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) + } + } +} + #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index b198dad..6e9af89 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -7,7 +7,7 @@ use crate::packets::mqtt_trait::MqttAsyncRead; use super::mqtt_trait::{MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketValidation, PacketWrite, WireLength}; use super::VariableInteger; -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::BufMut; use tokio::io::AsyncReadExt; #[derive(Debug, Clone, PartialEq, Eq)] @@ -101,6 +101,25 @@ impl PacketWrite for Unsubscribe { Ok(()) } } +impl crate::packets::mqtt_trait::PacketAsyncWrite for Unsubscribe +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + use tokio::io::AsyncWriteExt; + async move { + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; + + for topic in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + } + + Ok(total_written_bytes) + } + } +} impl WireLength for Unsubscribe { fn wire_len(&self) -> usize { diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 2799b15..6497dcd 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -65,7 +65,7 @@ where S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream - pub(crate) async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { + pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; self.last_network_action = Instant::now(); diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index e004ff8..df93161 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -233,6 +233,27 @@ pub fn create_publish_packet(qos: QoS, dup: bool, retain: bool, packet_identifie }) } +pub fn create_empty_publish_packet() -> Packet { + Packet::Publish(Publish { + dup: false, + qos: QoS::AtMostOnce, + retain: false, + topic: "test/#".into(), + packet_identifier: None, + publish_properties: PublishProperties { + payload_format_indicator: None, + message_expiry_interval: Some(3600), + topic_alias: Some(1), + response_topic: None, + correlation_data: Some(b"1212".to_vec()), + subscription_identifiers: vec![1], + user_properties: vec![], + content_type: Some("Garbage".into()), + }, + payload: vec![], + }) +} + pub fn create_puback_packet(packet_identifier: u16) -> Packet { Packet::PubAck(PubAck { packet_identifier, diff --git a/mqrstt/src/tokio/stream/mod.rs b/mqrstt/src/tokio/stream/mod.rs index 463d109..48071d3 100644 --- a/mqrstt/src/tokio/stream/mod.rs +++ b/mqrstt/src/tokio/stream/mod.rs @@ -13,8 +13,7 @@ use tracing::trace; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, - ConnAckReasonCode, - {FixedHeader, Packet}, + ConnAckReasonCode, {FixedHeader, Packet}, }; use crate::{connect_options::ConnectOptions, error::ConnectionError}; @@ -143,14 +142,20 @@ where } pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.write_buffer)?; + match packet.async_write(&mut self.stream).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + } + } + } #[cfg(feature = "logs")] trace!("Sending packet {}", packet); - self.stream.write_all(&self.write_buffer[..]).await?; self.stream.flush().await?; - self.write_buffer.clear(); Ok(()) } From fe0b6ba1cf8814886df0599165ff90bc059d0a8c Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:56:39 +0100 Subject: [PATCH 15/45] Small cargo changes etc --- .gitignore | 3 ++- Cargo.toml | 1 + rust-toolchain | 2 -- rustfmt.toml | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) delete mode 100644 rust-toolchain diff --git a/.gitignore b/.gitignore index 0b6e4ba..0a11e9d 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ examples/tokio_tls/Cargo.lock examples/smol_tls/Cargo.lock .vscode/** -Cargo.lock \ No newline at end of file +Cargo.lock +test.py \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 56f4f9f..a62cdfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,4 +2,5 @@ members = [ "mqrstt", + "fuzz", ] \ No newline at end of file diff --git a/rust-toolchain b/rust-toolchain deleted file mode 100644 index 9611677..0000000 --- a/rust-toolchain +++ /dev/null @@ -1,2 +0,0 @@ -[toolchain] -channel = "1.82.0" \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml index c9ce889..f494da5 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,4 @@ -# unstable_features = true -# brace_style = "PreferSameLine" -# control_brace_style = "ClosingNextLine" +# unstable_features = true +# brace_style = "PreferSameLine" +# control_brace_style = "ClosingNextLine" max_width = 200 \ No newline at end of file From 4014236ed620d7ac2435515bcddff40b3acef640 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:08:11 +0100 Subject: [PATCH 16/45] Fix bug not wiring variable length of properties --- mqrstt/src/packets/macros/properties_macros.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index 3befb6d..d845d9f 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -37,6 +37,7 @@ macro_rules! define_properties { fn async_write(&self, stream: &mut S) -> impl std::future::Future> { async move { let mut bytes_writen = 0; + $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; $( $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); )* From e343b71870e463d7e1c9ede1329500dbac4df55e Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:09:51 +0100 Subject: [PATCH 17/45] fix: write packet id in puback pubcomp pubrec and pubrel --- mqrstt/src/packets/puback/mod.rs | 4 +++- mqrstt/src/packets/pubcomp/mod.rs | 4 +++- mqrstt/src/packets/pubrec/mod.rs | 4 +++- mqrstt/src/packets/pubrel/mod.rs | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/mqrstt/src/packets/puback/mod.rs b/mqrstt/src/packets/puback/mod.rs index 7ab0ba0..3dd5559 100644 --- a/mqrstt/src/packets/puback/mod.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -105,7 +105,9 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_written_bytes = 0; + let mut total_written_bytes = 2; + self.packet_identifier.async_write(stream).await?; + if self.reason_code == PubAckReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 229036e..67e3450 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -129,7 +129,9 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 0; + let mut total_writen_bytes = 2; + self.packet_identifier.async_write(stream).await?; + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { return Ok(total_writen_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index aa36eaf..9fad939 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -120,7 +120,9 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 0; + let mut total_writen_bytes = 2; + self.packet_identifier.async_write(stream).await?; + if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { return Ok(total_writen_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index 6e42a76..4406536 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -115,7 +115,9 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 0; + let mut total_writen_bytes = 2; + self.packet_identifier.async_write(stream).await?; + if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { return Ok(total_writen_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { From df0a2fcce8cc3f50c85567a6d12f13c479441b4b Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:10:38 +0100 Subject: [PATCH 18/45] move disconnect async packet write --- mqrstt/src/packets/disconnect/mod.rs | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs index d8c1211..6ace759 100644 --- a/mqrstt/src/packets/disconnect/mod.rs +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -40,23 +40,6 @@ where } } -impl crate::packets::mqtt_trait::PacketAsyncWrite for Disconnect -where - S: tokio::io::AsyncWrite + Unpin, -{ - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - use crate::packets::mqtt_trait::MqttAsyncWrite; - async move { - let mut total_written_bytes = 0; - if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { - total_written_bytes += self.reason_code.async_write(stream).await?; - total_written_bytes += self.properties.async_write(stream).await?; - } - Ok(total_written_bytes) - } - } -} - impl PacketRead for Disconnect { fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let reason_code; @@ -81,6 +64,24 @@ impl PacketWrite for Disconnect { Ok(()) } } + +impl crate::packets::mqtt_trait::PacketAsyncWrite for Disconnect +where + S: tokio::io::AsyncWrite + Unpin, +{ + fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + use crate::packets::mqtt_trait::MqttAsyncWrite; + async move { + let mut total_written_bytes = 0; + if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; + } + Ok(total_written_bytes) + } + } +} + impl WireLength for Disconnect { fn wire_len(&self) -> usize { if self.reason_code != DisconnectReasonCode::NormalDisconnection || self.properties.wire_len() != 0 { From 14cc249cc44ecd43e782009b4506ab745c9e78f1 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:11:41 +0100 Subject: [PATCH 19/45] fix reading publish packet --- mqrstt/src/packets/publish/mod.rs | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index bfb17a3..71dc8a3 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -114,13 +114,6 @@ where let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; total_read_bytes += properties_read_bytes; - if total_read_bytes != remaining_length { - return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { - read: total_read_bytes, - remaining_length: remaining_length, - })); - } - if total_read_bytes > remaining_length { return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); } From b1f84bf230f77d7e75cd8807b674b2098ece5a42 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:13:12 +0100 Subject: [PATCH 20/45] adjust function names of reading and writing packets --- mqrstt/src/packets/connack/mod.rs | 2 +- mqrstt/src/packets/mod.rs | 55 ++++++----- mqrstt/src/smol/stream.rs | 7 +- mqrstt/src/tests/test_bytes.rs | 10 +- mqrstt/src/tests/test_packets.rs | 2 +- mqrstt/src/tokio/network.rs | 71 ++++++-------- mqrstt/src/tokio/stream/mod.rs | 135 +++++++------------------- mqrstt/src/tokio/stream/read_half.rs | 38 +------- mqrstt/src/tokio/stream/write_half.rs | 31 ++++-- 9 files changed, 129 insertions(+), 222 deletions(-) diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 23a49f7..06c400b 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -216,7 +216,7 @@ mod tests { p1.write(&mut buf).unwrap(); - let p2 = Packet::read_from_buffer(&mut buf).unwrap(); + let p2 = Packet::read(&mut buf).unwrap(); assert_eq!(p1, p2); } diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 208ae19..3846052 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -293,7 +293,7 @@ impl Packet { Ok(written) } - pub(crate) fn read(header: FixedHeader, buf: Bytes) -> Result { + pub(crate) fn read_packet(header: FixedHeader, buf: Bytes) -> Result { let packet = match header.packet_type { PacketType::Connect => Packet::Connect(Connect::read(header.flags, header.remaining_length, buf)?), PacketType::ConnAck => Packet::ConnAck(ConnAck::read(header.flags, header.remaining_length, buf)?), @@ -314,7 +314,7 @@ impl Packet { Ok(packet) } - pub(crate) async fn async_read(header: FixedHeader, stream: &mut S) -> Result + async fn async_read_packet(header: FixedHeader, stream: &mut S) -> Result where S: tokio::io::AsyncRead + Unpin, { @@ -338,16 +338,19 @@ impl Packet { Ok(packet) } - pub async fn async_read_from_buffer(stream: &mut S) -> Result + pub async fn async_read(stream: &mut S) -> Result where S: tokio::io::AsyncRead + Unpin, { let (header, _) = FixedHeader::async_read(stream).await?; - Ok(Packet::async_read(header, stream).await?) + #[cfg(feature = "logs")] + tracing::trace!("Read packet header: {:?}", header); + + Ok(Packet::async_read_packet(header, stream).await?) } - pub fn read_from_buffer(buffer: &mut BytesMut) -> Result> { + pub fn read(buffer: &mut BytesMut) -> Result> { use bytes::Buf; use error::ReadBytes; @@ -359,7 +362,7 @@ impl Packet { let buf = buffer.split_to(header.remaining_length); - Ok(Packet::read(header, buf.into())?) + Ok(Packet::read_packet(header, buf.into())?) } } @@ -460,26 +463,26 @@ mod tests { fn test_write_read_write_read_cases(#[case] packet: Packet) { let mut buffer = BytesMut::new(); packet.write(&mut buffer).unwrap(); - let res1 = Packet::read_from_buffer(&mut buffer).unwrap(); + let res1 = Packet::read(&mut buffer).unwrap(); let mut buffer = BytesMut::new(); res1.write(&mut buffer).unwrap(); - let res2 = Packet::read_from_buffer(&mut buffer).unwrap(); + let res2 = Packet::read(&mut buffer).unwrap(); assert_eq!(res1, res2); } #[rstest::rstest] - #[case(disconnect_case())] - #[case(ping_req_case())] - #[case(ping_resp_case())] - #[case(publish_case())] - #[case(pubrel_case())] - #[case(pubrel_smallest_case())] + #[case::disconnect(disconnect_case())] + #[case::ping_req(ping_req_case())] + #[case::ping_resp(ping_resp_case())] + #[case::publish(publish_case())] + #[case::pubrel(pubrel_case())] + #[case::pubrel_smallest(pubrel_smallest_case())] fn test_read_write_cases(#[case] (bytes, expected_packet): (&[u8], Packet)) { let mut buffer = BytesMut::from_iter(bytes); - let res = Packet::read_from_buffer(&mut buffer); + let res = Packet::read(&mut buffer); assert!(res.is_ok()); @@ -495,17 +498,17 @@ mod tests { } #[rstest::rstest] - #[case(disconnect_case())] - #[case(ping_req_case())] - #[case(ping_resp_case())] - #[case(publish_case())] - #[case(pubrel_case())] - #[case(pubrel_smallest_case())] + #[case::disconnect(disconnect_case())] + #[case::ping_req(ping_req_case())] + #[case::ping_resp(ping_resp_case())] + #[case::publish(publish_case())] + #[case::pubrel(pubrel_case())] + #[case::pubrel_smallest(pubrel_smallest_case())] #[tokio::test] async fn test_async_read_write(#[case] (mut bytes, expected_packet): (&[u8], Packet)) { - // let mut buffer = BytesMut::from(bytes); + let input = bytes.to_vec(); - let res = Packet::async_read_from_buffer(&mut bytes).await; + let res = Packet::async_read(&mut bytes).await; dbg!(&res); assert!(res.is_ok()); @@ -514,9 +517,11 @@ mod tests { assert_eq!(packet, expected_packet); - // packet.write(&mut buffer).unwrap(); + let mut out = Vec::with_capacity(1000); + + packet.async_write(&mut out).await.unwrap(); - // assert_eq!() + assert_eq!(out, input) } // #[rstest::rstest] diff --git a/mqrstt/src/smol/stream.rs b/mqrstt/src/smol/stream.rs index efa3a8f..2f1cbd8 100644 --- a/mqrstt/src/smol/stream.rs +++ b/mqrstt/src/smol/stream.rs @@ -10,8 +10,7 @@ use tracing::trace; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, - ConnAckReasonCode, - {FixedHeader, Packet}, + ConnAckReasonCode, {FixedHeader, Packet}, }; use crate::{connect_options::ConnectOptions, error::ConnectionError}; @@ -40,7 +39,7 @@ impl Stream { self.read_buffer.advance(header_length); let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; + let read_packet = Packet::read_packet(header, buf.into())?; #[cfg(feature = "logs")] trace!("Read packet from network {}", read_packet); @@ -96,7 +95,7 @@ where let buf = self.read_buffer.split_to(header.remaining_length); - return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); + return Packet::read_packet(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); } } diff --git a/mqrstt/src/tests/test_bytes.rs b/mqrstt/src/tests/test_bytes.rs index 04b21af..83fcc41 100644 --- a/mqrstt/src/tests/test_bytes.rs +++ b/mqrstt/src/tests/test_bytes.rs @@ -52,12 +52,12 @@ pub fn subscribe_packet() -> Vec { fn publish_packet_test(#[case] bytes: Vec) { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); assert_eq!(bytes.len(), write_buffer.len()); - let packet_from_write_buffer = Packet::read_from_buffer(&mut write_buffer).unwrap(); + let packet_from_write_buffer = Packet::read(&mut write_buffer).unwrap(); assert_eq!(packet, packet_from_write_buffer); } @@ -68,7 +68,7 @@ fn test_connect() { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); if let Packet::Connect(p) = &packet { @@ -79,7 +79,7 @@ fn test_connect() { assert_eq!(bytes.len(), write_buffer.len()); assert_eq!(bytes, write_buffer.to_vec()); - let packet_from_write_buffer = Packet::read_from_buffer(&mut write_buffer).unwrap(); + let packet_from_write_buffer = Packet::read(&mut write_buffer).unwrap(); assert_eq!(packet, packet_from_write_buffer); } @@ -93,7 +93,7 @@ fn test_connect() { fn test_equal_read_write_packet_from_bytes(#[case] bytes: Vec) { let mut read_buffer = BytesMut::from_iter(bytes.iter()); let mut write_buffer = BytesMut::new(); - let packet = Packet::read_from_buffer(&mut read_buffer).unwrap(); + let packet = Packet::read(&mut read_buffer).unwrap(); packet.write(&mut write_buffer).unwrap(); assert_eq!(bytes, write_buffer.to_vec()); diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index df93161..c8351e5 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -292,7 +292,7 @@ fn test_equal_write_read(#[case] packet: Packet) { packet.write(&mut buffer).unwrap(); - let read_packet = Packet::read_from_buffer(&mut buffer).unwrap(); + let read_packet = Packet::read(&mut buffer).unwrap(); assert_eq!(packet, read_packet); } diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index d526988..f78a174 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -125,7 +125,7 @@ where } = self; let mut await_pingresp = None; - let mut outgoing_packet_buffer = Vec::new(); + // let mut outgoing_packet_buffer = Vec::new(); loop { let sleep; @@ -137,40 +137,33 @@ where if let Some(stream) = network { tokio::select! { - res = stream.read_bytes() => { - res?; - loop{ - let packet = match stream.parse_message().await { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => break, - Ok(packet) => packet, - }; - match packet{ - Packet::PingResp => { - SequentialHandler::call_handler_await(handler, packet).await; - await_pingresp = None; - }, - Packet::Disconnect(_) => { - SequentialHandler::call_handler_await(handler, packet).await; - return Ok(NetworkStatus::IncomingDisconnect); - } - packet => { - match state_handler.handle_incoming_packet(&packet)? { - (maybe_reply_packet, true) => { - SequentialHandler::call_handler_await(handler, packet).await; - if let Some(reply_packet) = maybe_reply_packet { - outgoing_packet_buffer.push(reply_packet); - } - }, - (Some(reply_packet), false) => { - outgoing_packet_buffer.push(reply_packet); - }, - (None, false) => (), - } + res = stream.read() => { + let packet = res?; + match packet{ + Packet::PingResp => { + SequentialHandler::call_handler_await(handler, packet).await; + await_pingresp = None; + }, + Packet::Disconnect(_) => { + SequentialHandler::call_handler_await(handler, packet).await; + return Ok(NetworkStatus::IncomingDisconnect); + } + packet => { + match state_handler.handle_incoming_packet(&packet)? { + (maybe_reply_packet, true) => { + SequentialHandler::call_handler_await(handler, packet).await; + if let Some(reply_packet) = maybe_reply_packet { + stream.write(&reply_packet).await?; + *last_network_action = Instant::now(); + } + }, + (Some(reply_packet), false) => { + stream.write(&reply_packet).await?; + *last_network_action = Instant::now(); + }, + (None, false) => (), } } - stream.write_all(&mut outgoing_packet_buffer).await?; - *last_network_action = Instant::now(); } }, outgoing = to_network_r.recv() => { @@ -270,7 +263,7 @@ pub struct NetworkReader { impl NetworkReader where N: HandlerExt, - S: tokio::io::AsyncReadExt + Sized + Unpin + Send + 'static, + S: tokio::io::AsyncRead + Sized + Unpin + Send + 'static, { /// Runs the read half of the mqtt connection. /// Continuously loops until disconnect or error. @@ -289,16 +282,8 @@ where } async fn read(&mut self) -> Result { while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { - let _ = self.read_stream.read_bytes().await?; loop { - let packet = match self.read_stream.parse_message() { - Err(ReadBytes::Err(err)) => return Err(err), - Err(ReadBytes::InsufficientBytes(_)) => { - break; - } - Ok(packet) => packet, - }; - + let packet = self.read_stream.read().await?; match packet { Packet::PingResp => { N::call_handler(&mut self.handler, packet).await; diff --git a/mqrstt/src/tokio/stream/mod.rs b/mqrstt/src/tokio/stream/mod.rs index 48071d3..ced07d4 100644 --- a/mqrstt/src/tokio/stream/mod.rs +++ b/mqrstt/src/tokio/stream/mod.rs @@ -10,6 +10,7 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; #[cfg(feature = "logs")] use tracing::trace; +use crate::packets::error::WriteError; use crate::packets::ConnAck; use crate::packets::{ error::ReadBytes, @@ -22,36 +23,7 @@ use self::write_half::WriteStream; #[derive(Debug)] pub struct Stream { - pub stream: S, - - /// Input buffer - const_buffer: [u8; 4096], - - /// Write buffer - read_buffer: BytesMut, - - /// Write buffer - write_buffer: BytesMut, -} - -impl Stream { - pub async fn parse_message(&mut self) -> Result> { - let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; - - if header.remaining_length + header_length > self.read_buffer.len() { - return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - - #[cfg(feature = "logs")] - trace!("Read packet from network {}", read_packet); - - Ok(read_packet) - } + stream: S, } impl Stream @@ -59,33 +31,25 @@ where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { pub fn split(self) -> (ReadStream, WriteStream) { - let Self { - stream, - const_buffer, - read_buffer, - write_buffer, - } = self; + let Self { stream } = self; let (read_stream, write_stream) = tokio::io::split(stream); - (ReadStream::new(read_stream, const_buffer, read_buffer), WriteStream::new(write_stream, write_buffer)) + (ReadStream::new(read_stream), WriteStream::new(write_stream)) } pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, ConnAck), ConnectionError> { - let mut s = Self { - stream, - const_buffer: [0; 4096], - read_buffer: BytesMut::new(), - write_buffer: BytesMut::new(), - }; + let mut s = Self { stream }; let connect = options.create_connect_from_options(); s.write(&connect).await?; - let packet = s.read().await?; + let packet = Packet::async_read(&mut s.stream).await?; if let Packet::ConnAck(con) = packet { if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); Ok((s, con)) } else { Err(ConnectionError::ConnectionRefused(con.reason_code)) @@ -95,51 +59,32 @@ where } } - pub async fn read(&mut self) -> io::Result { - loop { - let (header, header_length) = match FixedHeader::read_fixed_header(self.read_buffer.iter()) { - Ok(header) => header, - Err(ReadBytes::InsufficientBytes(required_len)) => { - self.read_required_bytes(required_len).await?; - continue; - } - Err(ReadBytes::Err(err)) => return Err(Error::new(ErrorKind::InvalidData, err)), - }; - - if header_length + header.remaining_length > self.read_buffer.len() { - self.read_required_bytes(header.remaining_length - self.read_buffer.len()).await?; - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - - return Packet::read(header, buf.into()).map_err(|err| Error::new(ErrorKind::InvalidData, err)); - } - } - - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if read == 0 { - Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) - } else { - self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } + pub async fn read(&mut self) -> Result { + Ok(Packet::async_read(&mut self.stream).await?) } - /// Reads more than 'required' bytes to frame a packet into self.read buffer - pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - let mut total_read = 0; - - loop { - let read = self.read_bytes().await?; - total_read += read; - if total_read >= required { - return Ok(total_read); - } - } - } + // pub async fn read_bytes(&mut self) -> io::Result { + // let read = self.stream.read(&mut self.const_buffer).await?; + // if read == 0 { + // Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) + // } else { + // self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); + // Ok(read) + // } + // } + + // /// Reads more than 'required' bytes to frame a packet into self.read buffer + // pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { + // let mut total_read = 0; + + // loop { + // let read = self.read_bytes().await?; + // total_read += read; + // if total_read >= required { + // return Ok(total_read); + // } + // } + // } pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { match packet.async_write(&mut self.stream).await { @@ -151,31 +96,21 @@ where } } } + self.stream.flush().await?; #[cfg(feature = "logs")] trace!("Sending packet {}", packet); - self.stream.flush().await?; Ok(()) } pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { - let writes = packets.drain(0..).map(|packet| { - packet.write(&mut self.write_buffer)?; - + for packet in packets { + let _ = packet.async_write(&mut self.stream).await; #[cfg(feature = "logs")] trace!("Sending packet {}", packet); - - Ok::<(), ConnectionError>(()) - }); - - for write in writes { - write?; } - - self.stream.write_all(&self.write_buffer[..]).await?; self.stream.flush().await?; - self.write_buffer.clear(); Ok(()) } } diff --git a/mqrstt/src/tokio/stream/read_half.rs b/mqrstt/src/tokio/stream/read_half.rs index 0d5be0b..764802a 100644 --- a/mqrstt/src/tokio/stream/read_half.rs +++ b/mqrstt/src/tokio/stream/read_half.rs @@ -14,47 +14,17 @@ use tracing::trace; #[derive(Debug)] pub struct ReadStream { stream: ReadHalf, - - /// Input buffer - const_buffer: [u8; 4096], - - /// Write buffer - read_buffer: BytesMut, } impl ReadStream where S: tokio::io::AsyncRead + Sized + Unpin, { - pub fn new(stream: ReadHalf, const_buffer: [u8; 4096], read_buffer: BytesMut) -> Self { - Self { stream, const_buffer, read_buffer } - } - - pub fn parse_message(&mut self) -> Result> { - let (header, header_length) = FixedHeader::read_fixed_header(self.read_buffer.iter())?; - - if header.remaining_length + header_length > self.read_buffer.len() { - return Err(ReadBytes::InsufficientBytes(header.remaining_length - self.read_buffer.len())); - } - - self.read_buffer.advance(header_length); - - let buf = self.read_buffer.split_to(header.remaining_length); - let read_packet = Packet::read(header, buf.into())?; - - #[cfg(feature = "logs")] - trace!("Read packet from network {}", read_packet); - - Ok(read_packet) + pub fn new(stream: ReadHalf) -> Self { + Self { stream } } - pub async fn read_bytes(&mut self) -> io::Result { - let read = self.stream.read(&mut self.const_buffer).await?; - if read == 0 { - Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) - } else { - self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); - Ok(read) - } + pub async fn read(&mut self) -> Result { + Ok(Packet::async_read(&mut self.stream).await?) } } diff --git a/mqrstt/src/tokio/stream/write_half.rs b/mqrstt/src/tokio/stream/write_half.rs index 9bc5fb4..6d161a0 100644 --- a/mqrstt/src/tokio/stream/write_half.rs +++ b/mqrstt/src/tokio/stream/write_half.rs @@ -8,15 +8,12 @@ use tracing::trace; #[derive(Debug)] pub struct WriteStream { - pub stream: WriteHalf, - - /// Write buffer - write_buffer: BytesMut, + stream: WriteHalf, } impl WriteStream { - pub fn new(stream: WriteHalf, write_buffer: BytesMut) -> Self { - Self { stream, write_buffer } + pub fn new(stream: WriteHalf) -> Self { + Self { stream } } } @@ -25,14 +22,30 @@ where S: tokio::io::AsyncWrite + Sized + Unpin, { pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - packet.write(&mut self.write_buffer)?; + match packet.async_write(&mut self.stream).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + } + } + } #[cfg(feature = "logs")] trace!("Sending packet {}", packet); - self.stream.write_all(&self.write_buffer[..]).await?; self.stream.flush().await?; - self.write_buffer.clear(); + Ok(()) + } + + pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { + for packet in packets { + let _ = packet.async_write(&mut self.stream).await; + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); + } + self.stream.flush().await?; Ok(()) } } From 133f2f2c85e11411ff3360792f56a622518d696f Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:13:46 +0100 Subject: [PATCH 21/45] Impl from ReadError and WriteError into ConnectionError --- mqrstt/src/error.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index e4b018c..7651ceb 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -3,8 +3,8 @@ use std::io; use async_channel::{RecvError, SendError}; use crate::packets::{ - error::{DeserializeError, ReadBytes, SerializeError}, - ConnAckReasonCode, {Packet, PacketType}, + error::{DeserializeError, ReadBytes, ReadError, SerializeError, WriteError}, + ConnAckReasonCode, Packet, PacketType, }; /// Critical errors that can happen during the operation of the entire client @@ -42,6 +42,24 @@ pub enum ConnectionError { JoinError(#[from] tokio::task::JoinError), } +impl From for ConnectionError { + fn from(value: ReadError) -> Self { + match value { + ReadError::DeserializeError(deserialize_error) => ConnectionError::DeserializationError(deserialize_error), + ReadError::IoError(error) => ConnectionError::Io(error), + } + } +} + +impl From for ConnectionError { + fn from(value: WriteError) -> Self { + match value { + WriteError::SerializeError(error) => ConnectionError::SerializationError(error), + WriteError::IoError(error) => ConnectionError::Io(error), + } + } +} + /// Errors that the internal StateHandler can emit #[derive(Debug, Clone, thiserror::Error)] pub enum HandlerError { From 6cb9cd546200f1ad8e281ea351c4ec579fa71548 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:14:08 +0100 Subject: [PATCH 22/45] add first fuzzing target --- fuzz/.gitignore | 4 ++++ fuzz/Cargo.toml | 25 +++++++++++++++++++++++++ fuzz/fuzz_targets/fuzz_target_1.rs | 14 ++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 fuzz/.gitignore create mode 100644 fuzz/Cargo.toml create mode 100644 fuzz/fuzz_targets/fuzz_target_1.rs diff --git a/fuzz/.gitignore b/fuzz/.gitignore new file mode 100644 index 0000000..1a45eee --- /dev/null +++ b/fuzz/.gitignore @@ -0,0 +1,4 @@ +target +corpus +artifacts +coverage diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml new file mode 100644 index 0000000..adb4c3e --- /dev/null +++ b/fuzz/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "mqrstt-fuzz" +version = "0.0.0" +publish = false +edition = "2021" + +[package.metadata] +cargo-fuzz = true + +[dependencies] +libfuzzer-sys = "0.4" + +bytes = "1" + +tokio = { version = "1", features = ["full"] } + +[dependencies.mqrstt] +path = "../mqrstt" + +[[bin]] +name = "fuzz_target_1" +path = "fuzz_targets/fuzz_target_1.rs" +test = false +doc = false +bench = false diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs new file mode 100644 index 0000000..1d3f536 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -0,0 +1,14 @@ +#![no_main] + +use libfuzzer_sys::fuzz_target; + +#[tokio::main(flavor = "current_thread")] +async fn test(mut data: &[u8]) { + mqrstt::packets::Packet::async_read(&mut data).await; +} + +fuzz_target!(|data: &[u8]| { + // let mut packet = bytes::BytesMut::from(data); + // mqrstt::packets::Packet::read_from_buffer(&mut packet); + test(data); +}); From c0216b2b1c95352456d2ede052568d5088b10410 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:37:51 +0100 Subject: [PATCH 23/45] change AsyncReadExt to AsyncRead --- mqrstt/src/packets/auth/mod.rs | 2 +- mqrstt/src/packets/connack/mod.rs | 3 ++- mqrstt/src/packets/disconnect/mod.rs | 2 +- mqrstt/src/packets/mqtt_trait/mod.rs | 2 +- mqrstt/src/packets/puback/mod.rs | 6 +++--- mqrstt/src/packets/pubrel/mod.rs | 5 +++-- mqrstt/src/packets/subscribe/mod.rs | 2 +- 7 files changed, 12 insertions(+), 10 deletions(-) diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index 32d207a..74a2107 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -8,7 +8,7 @@ pub use reason_code::AuthReasonCode; use bytes::Bytes; use super::{ - mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketAsyncWrite, PacketRead, PacketWrite, WireLength}, + mqtt_trait::{MqttAsyncRead, MqttAsyncWrite, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, VariableInteger, }; diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 06c400b..2ff8650 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -10,6 +10,7 @@ use super::{ VariableInteger, }; use bytes::{Buf, BufMut}; +use tokio::io::AsyncReadExt; /// ConnAck packet is sent by the server in response to a [`crate::packets::Connect`] packet. /// @@ -111,7 +112,7 @@ pub struct ConnAckFlags { impl MqttAsyncRead for ConnAckFlags where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { fn async_read(stream: &mut S) -> impl std::future::Future> { async move { diff --git a/mqrstt/src/packets/disconnect/mod.rs b/mqrstt/src/packets/disconnect/mod.rs index 6ace759..e077fb3 100644 --- a/mqrstt/src/packets/disconnect/mod.rs +++ b/mqrstt/src/packets/disconnect/mod.rs @@ -20,7 +20,7 @@ pub struct Disconnect { impl PacketAsyncRead for Disconnect where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { if remaining_length == 0 { diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index 68d32ce..50ce760 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -10,7 +10,7 @@ pub(crate) trait PacketRead: Sized { pub(crate) trait PacketAsyncRead: Sized where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl Future>; } diff --git a/mqrstt/src/packets/puback/mod.rs b/mqrstt/src/packets/puback/mod.rs index 3dd5559..55f5886 100644 --- a/mqrstt/src/packets/puback/mod.rs +++ b/mqrstt/src/packets/puback/mod.rs @@ -1,13 +1,13 @@ mod reason_code; pub use reason_code::PubAckReasonCode; -use bytes::BufMut; - use super::{ error::DeserializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, PacketType, PropertyType, VariableInteger, }; +use bytes::BufMut; +use tokio::io::AsyncReadExt; /// The PUBACK Packet is the response to a PUBLISH Packet with QoS 1. /// Both the server and client can send a PUBACK packet. @@ -20,7 +20,7 @@ pub struct PubAck { impl PacketAsyncRead for PubAck where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { let packet_identifier = stream.read_u16().await?; diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index 4406536..c97d285 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -5,6 +5,7 @@ mod properties; pub use properties::PubRelProperties; use bytes::BufMut; +use tokio::io::AsyncReadExt; use super::{ error::{DeserializeError, ReadError}, @@ -58,7 +59,7 @@ impl PacketRead for PubRel { impl PacketAsyncRead for PubRel where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), ReadError> { let mut total_read_bytes = 0; @@ -352,7 +353,7 @@ mod tests { let mut stream = &*buf; // flags can be 0 because not used. // remaining_length must be at least 4 - let (p_ack, read_bytes) = PubRel::async_read(0, buf.len(), &mut stream).await.unwrap(); + let (p_ack, _) = PubRel::async_read(0, buf.len(), &mut stream).await.unwrap(); let mut result = BytesMut::new(); p_ack.write(&mut result).unwrap(); diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index e97cd36..28f4899 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -59,7 +59,7 @@ impl PacketRead for Subscribe { impl PacketAsyncRead for Subscribe where - S: tokio::io::AsyncReadExt + Unpin, + S: tokio::io::AsyncRead + Unpin, { fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { async move { From cdcf93acfa17fef104514c9e2fd1902c9f43a34b Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:38:42 +0100 Subject: [PATCH 24/45] Remove concurrent tokio support The approach is complicated and usecase minimal. Thus removing --- mqrstt/src/lib.rs | 106 +------- mqrstt/src/tokio/mod.rs | 87 ------ mqrstt/src/tokio/network.rs | 250 +----------------- mqrstt/src/tokio/{stream/mod.rs => stream.rs} | 49 +--- mqrstt/src/tokio/stream/read_half.rs | 30 --- mqrstt/src/tokio/stream/write_half.rs | 51 ---- 6 files changed, 23 insertions(+), 550 deletions(-) rename mqrstt/src/tokio/{stream/mod.rs => stream.rs} (60%) delete mode 100644 mqrstt/src/tokio/stream/read_half.rs delete mode 100644 mqrstt/src/tokio/stream/write_half.rs diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index 65b4462..ccd133c 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -194,14 +194,10 @@ impl NetworkBuilder { #[cfg(feature = "tokio")] impl NetworkBuilder where - H: AsyncEventHandlerMut, - S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler, + S: ::tokio::io::AsyncRead + ::tokio::io::AsyncWrite + Sized + Unpin, { - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] - /// This network is supposed to be ran on a single task/thread. The read and write operations happen one after the other. - /// This approach does not give the most speed in terms of reading and writing but provides a simple and easy to use client with low overhead for low throughput clients. - /// - /// For more throughput: [`NetworkBuilder::tokio_concurrent_network`] + /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] /// /// # Example /// ``` @@ -212,9 +208,9 @@ where /// ::new_from_options(options) /// .tokio_sequential_network(); /// ``` - pub fn tokio_sequential_network(self) -> (tokio::Network, MqttClient) + pub fn tokio_network(self) -> (tokio::Network, MqttClient) where - H: AsyncEventHandlerMut, + H: AsyncEventHandler, { let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); @@ -230,45 +226,13 @@ where } } -#[cfg(feature = "tokio")] -impl NetworkBuilder -where - H: AsyncEventHandler, - S: ::tokio::io::AsyncReadExt + ::tokio::io::AsyncWriteExt + Sized + Unpin, -{ - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] - /// # Example - /// - /// ``` - /// use mqrstt::ConnectOptions; - /// - /// let options = ConnectOptions::new("ExampleClient"); - /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> - /// ::new_from_options(options) - /// .tokio_concurrent_network(); - /// ``` - pub fn tokio_concurrent_network(self) -> (tokio::Network, MqttClient) { - let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); - - let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); - - let max_packet_size = self.options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = tokio::Network::new(self.options, to_network_r, apkids); - - (network, client) - } -} - #[cfg(feature = "smol")] impl NetworkBuilder where - H: AsyncEventHandlerMut, - S: ::smol::io::AsyncReadExt + ::smol::io::AsyncWriteExt + Sized + Unpin, + H: AsyncEventHandler, + S: ::smol::io::AsyncRead + ::smol::io::AsyncWrite + Sized + Unpin, { - /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncReadExt`] and [`::tokio::io::AsyncWriteExt`] + /// Creates the needed components to run the MQTT client using a stream that implements [`::tokio::io::AsyncRead`] and [`::tokio::io::AsyncWrite`] /// ``` /// let (mut network, client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream> /// ::new_from_client_id("ExampleClient") @@ -459,60 +423,6 @@ mod tokio_lib_test { use std::{sync::Arc, time::Duration}; - use crate::ConnectOptions; - - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] - async fn test_tokio_tcp() { - use std::hint::black_box; - - use crate::NetworkBuilder; - - let client_id: String = crate::random_chars() + "_TokioTcpPingPong"; - - let (mut network, client) = NetworkBuilder::new_from_client_id(client_id).tokio_concurrent_network(); - - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - - let mut pingpong = Arc::new(PingPong::new(client.clone())); - - network.connect(stream, &mut pingpong).await.unwrap(); - - let topic = crate::random_chars() + "_mqrstt"; - - client.subscribe((topic.as_str(), QoS::ExactlyOnce)).await.unwrap(); - - tokio::time::sleep(Duration::from_secs(5)).await; - - let (read, write) = network.split(pingpong.clone()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_result, write_result, _) = tokio::join!(read_handle, write_handle, async { - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".to_vec()).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".to_vec()).await.unwrap(); - client.publish(topic.as_str(), QoS::ExactlyOnce, false, b"ping".repeat(500)).await.unwrap(); - - client.unsubscribe(topic.as_str()).await.unwrap(); - - for _ in 0..30 { - tokio::time::sleep(Duration::from_secs(1)).await; - if pingpong.number.load(std::sync::atomic::Ordering::SeqCst) == 4 { - break; - } - } - - client.disconnect().await.unwrap(); - }); - - let write_result = write_result.unwrap(); - assert!(write_result.is_ok()); - assert_eq!(crate::NetworkStatus::OutgoingDisconnect, write_result.unwrap()); - assert_eq!(4, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); - let _ = black_box(read_result); - } - // #[tokio::test] // async fn test_tokio_ping_req() { // let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); diff --git a/mqrstt/src/tokio/mod.rs b/mqrstt/src/tokio/mod.rs index 234ec18..4aef142 100644 --- a/mqrstt/src/tokio/mod.rs +++ b/mqrstt/src/tokio/mod.rs @@ -2,91 +2,4 @@ mod stream; pub(crate) mod network; -use futures::Future; pub use network::Network; -pub use network::{NetworkReader, NetworkWriter}; - -use crate::error::ConnectionError; -use crate::packets::Packet; - -/// This empty struct is used to indicate the handling of messages goes via a mutable handler. -/// Only a single mutable reference can exist at once. -/// Thus this kind is not for concurrent message handling but for concurrent TCP read and write operations. -pub struct SequentialHandler; - -/// This empty struct is used to indicate a (tokio) task based handling of messages. -/// Per incoming message a task is spawned to call the handler. -/// -/// This kind of handler is used for both concurrent message handling and concurrent TCP read and write operations. -pub struct ConcurrentHandler; - -pub trait HandlerExt: Sized { - /// Should call the handler in the fashion of the handler. - /// (e.g. spawn a task if or await the handle call) - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; - - /// Should call the handler and await it - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send; - - /// Should call the handler in the fashion of the handler. - /// (e.g. spawn a task if or await the handle call) - /// The reply (e.g. an ACK) to the original packet is only send when the handle call has completed - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send; -} - -impl HandlerExt for SequentialHandler { - #[inline] - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - #[inline] - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send, - { - async { - network.handler.handle(incoming_packet).await; - if let Some(reply_packet) = reply_packet { - network.to_writer_s.send(reply_packet).await?; - } - Ok(()) - } - } -} - -impl HandlerExt for ConcurrentHandler { - fn call_handler(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - let handler_clone = handler.clone(); - tokio::spawn(async move { - handler_clone.handle(incoming_packet).await; - }); - std::future::ready(()) - } - #[inline] - fn call_handler_await(handler: &mut H, incoming_packet: Packet) -> impl Future + Send { - handler.handle(incoming_packet) - } - - fn call_handler_with_reply(network: &mut NetworkReader, incoming_packet: Packet, reply_packet: Option) -> impl Future> + Send - where - S: Send, - { - let handler_clone = network.handler.clone(); - let write_channel_clone = network.to_writer_s.clone(); - - network.join_set.spawn(async move { - handler_clone.handle(incoming_packet).await; - if let Some(reply_packet) = reply_packet { - write_channel_clone.send(reply_packet).await?; - } - Ok(()) - }); - - std::future::ready(Ok(())) - } -} diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index f78a174..84efd01 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -1,29 +1,25 @@ -use async_channel::{Receiver, Sender}; -use tokio::task::JoinSet; +use async_channel::Receiver; use std::marker::PhantomData; -use std::sync::atomic::AtomicBool; + use std::sync::Arc; use std::time::{Duration, Instant}; use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::ConnectionError; -use crate::packets::error::ReadBytes; use crate::packets::DisconnectReasonCode; use crate::packets::{Disconnect, Packet, PacketType}; -use crate::{AsyncEventHandlerMut, NetworkStatus, StateHandler}; +use crate::{AsyncEventHandler, NetworkStatus, StateHandler}; use super::stream::Stream; -use super::{HandlerExt, SequentialHandler}; /// [`Network`] reads and writes to the network based on tokios [`::tokio::io::AsyncReadExt`] [`::tokio::io::AsyncWriteExt`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). -pub struct Network { - handler_helper: PhantomData, +pub struct Network { handler: PhantomData, network: Option>, @@ -35,10 +31,9 @@ pub struct Network { to_network_r: Receiver, } -impl Network { +impl Network { pub(crate) fn new(options: ConnectOptions, to_network_r: Receiver, apkids: AvailablePacketIds) -> Self { Self { - handler_helper: PhantomData, handler: PhantomData, network: None, @@ -54,10 +49,9 @@ impl Network { } } -/// Tokio impl -impl Network +impl Network where - N: HandlerExt, + H: AsyncEventHandler, S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream @@ -73,7 +67,7 @@ where } let packets = self.state_handler.handle_incoming_connack(&conn_ack)?; - N::call_handler_await(handler, Packet::ConnAck(conn_ack)).await; + handler.handle(Packet::ConnAck(conn_ack)).await; if let Some(mut packets) = packets { network.write_all(&mut packets).await?; self.last_network_action = Instant::now(); @@ -85,10 +79,9 @@ where } } -impl Network +impl Network where - H: AsyncEventHandlerMut, - SequentialHandler: HandlerExt, + H: AsyncEventHandler, S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// A single call to run will perform one of three tasks: @@ -119,7 +112,6 @@ where last_network_action, perform_keep_alive, to_network_r, - handler_helper: _, handler: _, state_handler, } = self; @@ -141,17 +133,17 @@ where let packet = res?; match packet{ Packet::PingResp => { - SequentialHandler::call_handler_await(handler, packet).await; + handler.handle(packet).await; await_pingresp = None; }, Packet::Disconnect(_) => { - SequentialHandler::call_handler_await(handler, packet).await; + handler.handle(packet).await; return Ok(NetworkStatus::IncomingDisconnect); } packet => { match state_handler.handle_incoming_packet(&packet)? { (maybe_reply_packet, true) => { - SequentialHandler::call_handler_await(handler, packet).await; + handler.handle(packet).await; if let Some(reply_packet) = maybe_reply_packet { stream.write(&reply_packet).await?; *last_network_action = Instant::now(); @@ -197,219 +189,3 @@ where } } } - -impl Network -where - S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, -{ - /// Creates both read and write tasks to run this them in parallel. - /// If you want to run concurrently (not parallel) the [`Self::run`] method is a better aproach! - pub fn split(&mut self, handler: H) -> Result<(NetworkReader, NetworkWriter), ConnectionError> { - if self.network.is_none() { - return Err(ConnectionError::NoNetwork)?; - } - - match self.network.take() { - Some(network) => { - let (read_stream, write_stream) = network.split(); - let run_signal = Arc::new(AtomicBool::new(true)); - let (to_writer_s, to_writer_r) = async_channel::bounded(100); - let await_pingresp_atomic = Arc::new(AtomicBool::new(false)); - - let read_network = NetworkReader { - run_signal: run_signal.clone(), - handler_helper: PhantomData, - handler: handler, - read_stream, - await_pingresp_atomic: await_pingresp_atomic.clone(), - state_handler: self.state_handler.clone(), - to_writer_s, - join_set: JoinSet::new(), - }; - - let write_network = NetworkWriter { - run_signal: run_signal.clone(), - write_stream, - keep_alive_interval: self.options.keep_alive_interval, - last_network_action: self.last_network_action, - await_pingresp_bool: await_pingresp_atomic.clone(), - await_pingresp_time: None, - perform_keep_alive: self.perform_keep_alive, - state_handler: self.state_handler.clone(), - to_writer_r: to_writer_r, - to_network_r: self.to_network_r.clone(), - }; - - Ok((read_network, write_network)) - } - None => Err(ConnectionError::NoNetwork), - } - } -} - -pub struct NetworkReader { - pub(crate) run_signal: Arc, - - pub(crate) handler_helper: PhantomData, - pub handler: H, - - pub(crate) read_stream: super::stream::read_half::ReadStream, - pub(crate) await_pingresp_atomic: Arc, - pub(crate) state_handler: Arc, - pub(crate) to_writer_s: Sender, - pub(crate) join_set: JoinSet>, -} - -impl NetworkReader -where - N: HandlerExt, - S: tokio::io::AsyncRead + Sized + Unpin + Send + 'static, -{ - /// Runs the read half of the mqtt connection. - /// Continuously loops until disconnect or error. - /// - /// # Return - /// - Ok(None) in the case that the write task requested shutdown. - /// - Ok(Some(reason)) in the case that this task initiates a shutdown. - /// - Err in the case of IO, or protocol errors. - pub async fn run(mut self) -> (Result, H) { - let ret = self.read().await; - self.run_signal.store(false, std::sync::atomic::Ordering::Release); - while let Some(_) = self.join_set.join_next().await { - () - } - (ret, self.handler) - } - async fn read(&mut self) -> Result { - while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { - loop { - let packet = self.read_stream.read().await?; - match packet { - Packet::PingResp => { - N::call_handler(&mut self.handler, packet).await; - #[cfg(feature = "logs")] - if !self.await_pingresp_atomic.fetch_and(false, std::sync::atomic::Ordering::SeqCst) { - tracing::warn!("Received PingResp but did not expect it"); - } - #[cfg(not(feature = "logs"))] - self.await_pingresp_atomic.store(false, std::sync::atomic::Ordering::SeqCst); - } - Packet::Disconnect(_) => { - N::call_handler(&mut self.handler, packet).await; - return Ok(NetworkStatus::IncomingDisconnect); - } - Packet::ConnAck(conn_ack) => { - if let Some(retransmit_packets) = self.state_handler.handle_incoming_connack(&conn_ack)? { - for packet in retransmit_packets.into_iter() { - self.to_writer_s.send(packet).await?; - } - } - N::call_handler(&mut self.handler, Packet::ConnAck(conn_ack)).await; - } - packet => match self.state_handler.handle_incoming_packet(&packet)? { - (maybe_reply_packet, true) => { - N::call_handler_with_reply(self, packet, maybe_reply_packet).await?; - } - (Some(reply_packet), false) => { - self.to_writer_s.send(reply_packet).await?; - } - (None, false) => (), - }, - } - } - } - Ok(NetworkStatus::ShutdownSignal) - } -} - -pub struct NetworkWriter { - run_signal: Arc, - - write_stream: super::stream::write_half::WriteStream, - - keep_alive_interval: Duration, - - last_network_action: Instant, - await_pingresp_bool: Arc, - await_pingresp_time: Option, - perform_keep_alive: bool, - - state_handler: Arc, - - to_writer_r: Receiver, - to_network_r: Receiver, -} - -impl NetworkWriter -where - S: tokio::io::AsyncWriteExt + Sized + Unpin, -{ - /// Runs the read half of the mqtt connection. - /// Continuously loops until disconnect or error. - /// - /// # Return - /// - Ok(None) in the case that the read task requested shutdown - /// - Ok(Some(reason)) in the case that this task initiates a shutdown - /// - Err in the case of IO, or protocol errors. - pub async fn run(mut self) -> Result { - let ret = self.write().await; - self.run_signal.store(false, std::sync::atomic::Ordering::Release); - ret - } - async fn write(&mut self) -> Result { - while self.run_signal.load(std::sync::atomic::Ordering::Acquire) { - if self.await_pingresp_time.is_some() && !self.await_pingresp_bool.load(std::sync::atomic::Ordering::Acquire) { - self.await_pingresp_time = None; - } - - let sleep; - if let Some(instant) = &self.await_pingresp_time { - sleep = *instant + self.keep_alive_interval - Instant::now(); - } else { - sleep = self.last_network_action + self.keep_alive_interval - Instant::now(); - }; - tokio::select! { - outgoing = self.to_network_r.recv() => { - let packet = outgoing?; - self.write_stream.write(&packet).await?; - - let disconnect = packet.packet_type() == PacketType::Disconnect; - - self.state_handler.handle_outgoing_packet(packet)?; - self.last_network_action = Instant::now(); - - if disconnect{ - return Ok(NetworkStatus::OutgoingDisconnect); - } - }, - from_reader = self.to_writer_r.recv() => { - let packet = from_reader?; - self.write_stream.write(&packet).await?; - match packet { - foo @ (Packet::Publish(_) | Packet::Subscribe(_) | Packet::Unsubscribe(_) | Packet::Disconnect(_)) => { - self.state_handler.handle_outgoing_packet(foo)?; - }, - _ => (), - } - self.last_network_action = Instant::now(); - }, - _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_none() && self.perform_keep_alive => { - let packet = Packet::PingReq; - self.write_stream.write(&packet).await?; - self.await_pingresp_bool.store(true, std::sync::atomic::Ordering::SeqCst); - self.last_network_action = Instant::now(); - self.await_pingresp_time = Some(Instant::now()); - }, - _ = tokio::time::sleep(sleep), if self.await_pingresp_time.is_some() => { - self.await_pingresp_time = None; - if self.await_pingresp_bool.load(std::sync::atomic::Ordering::SeqCst){ - let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - self.write_stream.write(&Packet::Disconnect(disconnect)).await?; - return Ok(NetworkStatus::KeepAliveTimeout); - } - } - } - } - Ok(NetworkStatus::ShutdownSignal) - } -} diff --git a/mqrstt/src/tokio/stream/mod.rs b/mqrstt/src/tokio/stream.rs similarity index 60% rename from mqrstt/src/tokio/stream/mod.rs rename to mqrstt/src/tokio/stream.rs index ced07d4..081d720 100644 --- a/mqrstt/src/tokio/stream/mod.rs +++ b/mqrstt/src/tokio/stream.rs @@ -1,26 +1,12 @@ -pub mod read_half; -pub mod write_half; - -use std::io::{self, Error, ErrorKind}; - -use bytes::{Buf, BytesMut}; - -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::AsyncWriteExt; #[cfg(feature = "logs")] use tracing::trace; -use crate::packets::error::WriteError; use crate::packets::ConnAck; -use crate::packets::{ - error::ReadBytes, - ConnAckReasonCode, {FixedHeader, Packet}, -}; +use crate::packets::{ConnAckReasonCode, Packet}; use crate::{connect_options::ConnectOptions, error::ConnectionError}; -use self::read_half::ReadStream; -use self::write_half::WriteStream; - #[derive(Debug)] pub struct Stream { stream: S, @@ -30,14 +16,6 @@ impl Stream where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { - pub fn split(self) -> (ReadStream, WriteStream) { - let Self { stream } = self; - - let (read_stream, write_stream) = tokio::io::split(stream); - - (ReadStream::new(read_stream), WriteStream::new(write_stream)) - } - pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, ConnAck), ConnectionError> { let mut s = Self { stream }; @@ -63,29 +41,6 @@ where Ok(Packet::async_read(&mut self.stream).await?) } - // pub async fn read_bytes(&mut self) -> io::Result { - // let read = self.stream.read(&mut self.const_buffer).await?; - // if read == 0 { - // Err(io::Error::new(io::ErrorKind::ConnectionReset, "Connection reset by peer")) - // } else { - // self.read_buffer.extend_from_slice(&self.const_buffer[0..read]); - // Ok(read) - // } - // } - - // /// Reads more than 'required' bytes to frame a packet into self.read buffer - // pub async fn read_required_bytes(&mut self, required: usize) -> io::Result { - // let mut total_read = 0; - - // loop { - // let read = self.read_bytes().await?; - // total_read += read; - // if total_read >= required { - // return Ok(total_read); - // } - // } - // } - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { match packet.async_write(&mut self.stream).await { Ok(_) => (), diff --git a/mqrstt/src/tokio/stream/read_half.rs b/mqrstt/src/tokio/stream/read_half.rs deleted file mode 100644 index 764802a..0000000 --- a/mqrstt/src/tokio/stream/read_half.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::io; - -use bytes::{Buf, BytesMut}; -use tokio::io::{AsyncReadExt, ReadHalf}; - -use crate::{ - error::ConnectionError, - packets::{error::ReadBytes, FixedHeader, Packet}, -}; - -#[cfg(feature = "logs")] -use tracing::trace; - -#[derive(Debug)] -pub struct ReadStream { - stream: ReadHalf, -} - -impl ReadStream -where - S: tokio::io::AsyncRead + Sized + Unpin, -{ - pub fn new(stream: ReadHalf) -> Self { - Self { stream } - } - - pub async fn read(&mut self) -> Result { - Ok(Packet::async_read(&mut self.stream).await?) - } -} diff --git a/mqrstt/src/tokio/stream/write_half.rs b/mqrstt/src/tokio/stream/write_half.rs deleted file mode 100644 index 6d161a0..0000000 --- a/mqrstt/src/tokio/stream/write_half.rs +++ /dev/null @@ -1,51 +0,0 @@ -use bytes::BytesMut; -use tokio::io::{AsyncWriteExt, WriteHalf}; - -use crate::{error::ConnectionError, packets::Packet}; - -#[cfg(feature = "logs")] -use tracing::trace; - -#[derive(Debug)] -pub struct WriteStream { - stream: WriteHalf, -} - -impl WriteStream { - pub fn new(stream: WriteHalf) -> Self { - Self { stream } - } -} - -impl WriteStream -where - S: tokio::io::AsyncWrite + Sized + Unpin, -{ - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - match packet.async_write(&mut self.stream).await { - Ok(_) => (), - Err(err) => { - return match err { - crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), - crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), - } - } - } - - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - - self.stream.flush().await?; - Ok(()) - } - - pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { - for packet in packets { - let _ = packet.async_write(&mut self.stream).await; - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); - } - self.stream.flush().await?; - Ok(()) - } -} From 9c1cbc218280e733d29cd1495dcacef2d07edd68 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:39:41 +0100 Subject: [PATCH 25/45] Remove Concurrent event handler Removing concurrent event handler Renaming mutable event handler to normal event handler --- mqrstt/src/event_handlers.rs | 74 +++--------------------------------- mqrstt/src/smol/network.rs | 4 +- 2 files changed, 7 insertions(+), 71 deletions(-) diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 0941095..2b786fe 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -1,55 +1,18 @@ -use std::sync::Arc; - use futures::Future; use crate::packets::Packet; - -/// Handlers are used to deal with packets before they are acknowledged to the broker. -/// This guarantees that the end user has handlded the packet. Additionally, handlers only deal with incoming packets. -/// -/// This handler can be used to handle message sequentialy or concurrently. -/// -/// To send messages look at [`crate::MqttClient`] -pub trait AsyncEventHandler { - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync; -} -impl AsyncEventHandler for &T -where - T: AsyncEventHandler, -{ - #[inline] - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync { - AsyncEventHandler::handle(*self, incoming_packet) - } -} -impl AsyncEventHandler for Arc -where - T: AsyncEventHandler, -{ - #[inline] - fn handle(&self, incoming_packet: Packet) -> impl Future + Send + Sync { - ::handle(&self, incoming_packet) - } -} -/// This is a simple no operation handler. -impl AsyncEventHandler for () { - fn handle(&self, _: Packet) -> impl Future + Send + Sync { - async {} - } -} - /// Handlers are used to deal with packets before they are acknowledged to the broker. /// This guarantees that the end user has handlded the packet. Additionally, handlers only deal with incoming packets. /// /// This handler can be used to handle message sequentialy. /// /// To send messages look at [`crate::MqttClient`] -pub trait AsyncEventHandlerMut { +pub trait AsyncEventHandler { fn handle(&mut self, incoming_packet: Packet) -> impl Future + Send + Sync; } /// This is a simple no operation handler. -impl AsyncEventHandlerMut for () { +impl AsyncEventHandler for () { fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { async {} } @@ -70,14 +33,14 @@ pub mod example_handlers { use crate::{ packets::{self, Packet}, - AsyncEventHandler, AsyncEventHandlerMut, EventHandler, MqttClient, + AsyncEventHandler, EventHandler, MqttClient, }; /// Most basic no op handler /// This handler performs no operations on incoming messages. pub struct NOP {} - impl AsyncEventHandlerMut for NOP { + impl AsyncEventHandler for NOP { async fn handle(&mut self, _: Packet) {} } @@ -99,7 +62,7 @@ pub mod example_handlers { } } - impl AsyncEventHandlerMut for PingResp { + impl AsyncEventHandler for PingResp { async fn handle(&mut self, event: packets::Packet) -> () { use Packet::*; if event == PingResp { @@ -131,33 +94,6 @@ pub mod example_handlers { } impl AsyncEventHandler for PingPong { - async fn handle(&self, event: packets::Packet) -> () { - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - // let max_len = payload.len().min(10); - // let a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - // println!("Received publish payload: {}", a); - - if !p.retain { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - - // println!("DBG: \n {}", &Packet::Publish(p)); - } - } - } - Packet::ConnAck(_) => { - // println!("Connected!") - } - _ => (), - } - } - } - - impl AsyncEventHandlerMut for PingPong { async fn handle(&mut self, event: packets::Packet) -> () { match event { Packet::Publish(p) => { diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 6497dcd..0d96c29 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -11,7 +11,7 @@ use crate::error::ConnectionError; use crate::packets::error::ReadBytes; use crate::packets::{Disconnect, DisconnectReasonCode, Packet, PacketType}; use crate::NetworkStatus; -use crate::{AsyncEventHandlerMut, StateHandler}; +use crate::{AsyncEventHandler, StateHandler}; use super::stream::Stream; @@ -61,7 +61,7 @@ impl Network { impl Network where - H: AsyncEventHandlerMut, + H: AsyncEventHandler, S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream From 822a666f9d157b5ced5d6d52bae7d7c24bfa397a Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:40:26 +0100 Subject: [PATCH 26/45] Return on err writing variable integer --- mqrstt/src/packets/primitive/variable_integer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs index 276d613..9ee609f 100644 --- a/mqrstt/src/packets/primitive/variable_integer.rs +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -147,7 +147,7 @@ impl VariableInteger for usize { break; } } - stream.write_all(&buf[0..length]).await; + stream.write_all(&buf[0..length]).await?; Ok(length) } } @@ -249,7 +249,7 @@ impl VariableInteger for u32 { break; } } - stream.write_all(&buf[0..length]).await; + stream.write_all(&buf[0..length]).await?; Ok(length) } } From fa087b7eff1cebaa235292abb35c3edfa2e0a86e Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:41:46 +0100 Subject: [PATCH 27/45] Removing benchmarks Benchmarks are being removed as they are not possible anymore with the public api. Additionally, they are not as up to date anymore --- mqrstt/Cargo.toml | 7 +- mqrstt/benches/bench_main.rs | 8 - mqrstt/benches/benchmarks/mod.rs | 159 ---------------- mqrstt/benches/benchmarks/tokio.rs | 285 ----------------------------- 4 files changed, 1 insertion(+), 458 deletions(-) delete mode 100644 mqrstt/benches/bench_main.rs delete mode 100644 mqrstt/benches/benchmarks/mod.rs delete mode 100644 mqrstt/benches/benchmarks/tokio.rs diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index 54c19f6..705ba9a 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -74,9 +74,4 @@ webpki = { version = "0.22.4" } async-rustls = { version = "0.4.1" } tokio-rustls = "0.24.1" rstest = "0.23.0" -rand = "0.8.5" - - -[[bench]] -name = "bench_main" -harness = false +rand = "0.8.5" \ No newline at end of file diff --git a/mqrstt/benches/bench_main.rs b/mqrstt/benches/bench_main.rs deleted file mode 100644 index 32d301e..0000000 --- a/mqrstt/benches/bench_main.rs +++ /dev/null @@ -1,8 +0,0 @@ -use criterion::criterion_main; - -mod benchmarks; - -criterion_main! { - benchmarks::tokio::tokio_concurrent, - benchmarks::tokio::tokio_sequential, -} diff --git a/mqrstt/benches/benchmarks/mod.rs b/mqrstt/benches/benchmarks/mod.rs deleted file mode 100644 index 18fcb67..0000000 --- a/mqrstt/benches/benchmarks/mod.rs +++ /dev/null @@ -1,159 +0,0 @@ -use bytes::{BufMut, Bytes, BytesMut}; -use mqrstt::packets::{Packet, Publish}; - -pub mod tokio; - -fn fill_stuff(buffer: &mut BytesMut, publ_count: usize, publ_size: usize) { - // empty_connect(buffer); - // for i in 0..publ_count { - // very_large_publish(i as u16, publ_size / 5).write(buffer).unwrap(); - // } - // empty_disconnect().write(buffer).unwrap(); -} - -fn empty_disconnect() -> Packet { - // let discon = Disconnect { - // reason_code: mqrstt::packets::reason_codes::DisconnectReasonCode::ServerBusy, - // properties: Default::default(), - // }; - - // Packet::Disconnect(discon) - todo!() -} - -fn empty_connect(buffer: &mut BytesMut) { - // let conn_ack = ConnAck{ - // connack_flags: ConnAckFlags::default(), - // reason_code: mqrstt::packets::reason_codes::ConnAckReasonCode::Success, - // connack_properties: Default::default(), - // }; - - // Packet::ConnAck(conn_ack) - // buffer.put_u8(0b0010_0000); // Connack flags - // buffer.put_u8(0x01); // Connack flags - // buffer.put_u8(0x00); // Reason code, - // buffer.put_u8(0x00); // empty properties - - buffer.put_u8(0x20); - buffer.put_u8(0x13); - buffer.put_u8(0x00); - buffer.put_u8(0x00); - buffer.put_u8(0x10); - buffer.put_u8(0x27); - buffer.put_u8(0x06); - buffer.put_u8(0x40); - buffer.put_u8(0x00); - buffer.put_u8(0x00); - buffer.put_u8(0x25); - buffer.put_u8(0x01); - buffer.put_u8(0x2a); - buffer.put_u8(0x01); - buffer.put_u8(0x29); - buffer.put_u8(0x01); - buffer.put_u8(0x22); - buffer.put_u8(0xff); - buffer.put_u8(0xff); - buffer.put_u8(0x28); - buffer.put_u8(0x01); -} - -/// Returns Publish Packet with 5x `repeat` as payload in bytes. -fn very_large_publish(id: u16, repeat: usize) -> Packet { - let publ = Publish { - dup: false, - qos: mqrstt::packets::QoS::ExactlyOnce, - retain: false, - topic: "BlaBla".into(), - packet_identifier: Some(id), - publish_properties: Default::default(), - payload: b"ping".repeat(repeat).to_vec(), - }; - - Packet::Publish(publ) -} - -mod test_handlers { - use std::{ - sync::{atomic::AtomicU16, Arc}, - time::Duration, - }; - - use bytes::Bytes; - use mqrstt::{ - packets::{self, Packet}, - AsyncEventHandler, AsyncEventHandlerMut, MqttClient, - }; - - pub struct PingPong { - pub client: MqttClient, - pub number: Arc, - } - - impl PingPong { - pub fn new(client: MqttClient) -> Self { - Self { - client, - number: Arc::new(AtomicU16::new(0)), - } - } - } - - impl AsyncEventHandler for PingPong { - async fn handle(&self, event: packets::Packet) -> () { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - let max_len = payload.len().min(10); - let _a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - } - } - } - Packet::ConnAck(_) => (), - _ => (), - } - } - } - - impl AsyncEventHandlerMut for PingPong { - async fn handle(&mut self, event: packets::Packet) -> () { - self.number.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - match event { - Packet::Publish(p) => { - if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { - let max_len = payload.len().min(10); - let _a = &payload[0..max_len]; - if payload.to_lowercase().contains("ping") { - self.client.publish(p.topic.clone(), p.qos, p.retain, Bytes::from_static(b"pong")).await.unwrap(); - } - } - } - Packet::ConnAck(_) => (), - _ => (), - } - } - } - - pub struct SimpleDelay { - delay: Duration, - } - - impl SimpleDelay { - pub fn new(delay: Duration) -> Self { - Self { delay } - } - } - - impl AsyncEventHandler for SimpleDelay { - fn handle(&self, _: Packet) -> impl futures::prelude::Future + Send + Sync { - tokio::time::sleep(self.delay) - } - } - impl AsyncEventHandlerMut for SimpleDelay { - fn handle(&mut self, _: Packet) -> impl futures::prelude::Future + Send + Sync { - tokio::time::sleep(self.delay) - } - } -} diff --git a/mqrstt/benches/benchmarks/tokio.rs b/mqrstt/benches/benchmarks/tokio.rs deleted file mode 100644 index 74ce0b9..0000000 --- a/mqrstt/benches/benchmarks/tokio.rs +++ /dev/null @@ -1,285 +0,0 @@ -use std::{hint::black_box, io::Write, net::SocketAddr, sync::Arc, time::Duration}; - -use bytes::BytesMut; -use criterion::{criterion_group, Criterion}; -use mqrstt::{ConnectOptions, NetworkBuilder, NetworkStatus}; -use tokio::net::TcpStream; - -use crate::benchmarks::test_handlers::{PingPong, SimpleDelay}; - -use super::fill_stuff; - -fn tokio_setup() -> (TcpStream, std::net::TcpStream, SocketAddr) { - let mut buffer = BytesMut::new(); - - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let tcp_stream = std::net::TcpStream::connect(addr).unwrap(); - - let (mut server, _addr) = listener.accept().unwrap(); - - fill_stuff(&mut buffer, 100, 5_000_000); - - server.write_all(&buffer.to_vec()).unwrap(); - - let tcp_stream = tokio::net::TcpStream::from_std(tcp_stream).unwrap(); - (tcp_stream, server, _addr) -} - -fn tokio_concurrent_benchmarks(c: &mut Criterion) { - let mut group = c.benchmark_group("Tokio concurrent read, write and handling"); - group.sample_size(30); - group.measurement_time(Duration::from_secs(120)); - - group.bench_function("tokio_bench_concurrent_read_write_and_handling_NOP", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - let (read, write) = network.split(()).unwrap(); - - let _network_box = black_box(network); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - let read_res = read_res.unwrap(); - assert_eq!(read_res, NetworkStatus::IncomingDisconnect); - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_and_handling_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let mut pingpong = Arc::new(PingPong::new(client.clone())); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - let (read, write) = network.split(pingpong.clone()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - let read_res = read_res.unwrap(); - assert_eq!(read_res, NetworkStatus::IncomingDisconnect); - assert_eq!(102, pingpong.number.load(std::sync::atomic::Ordering::SeqCst)); - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _server_box = black_box(client.clone()); - let _server_box = black_box(server); - let _addr_box = black_box(addr); - let _network_box = black_box(network); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_and_handling_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_concurrent_network(); - - let _server_box = black_box(client); - - let mut handler = Arc::new(SimpleDelay::new(Duration::from_millis(100))); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - let (read, write) = network.split(handler).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _network_box = black_box(network); - }) - }); - - group.bench_function("tokio_bench_concurrent_read_write", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - - let (read, write) = network.split(()).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = tokio::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut pingpong = PingPong::new(client.clone()); - - let num_packets_received = pingpong.number.clone(); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - let (read, write) = network.split(pingpong).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - assert_eq!(102, num_packets_received.load(std::sync::atomic::Ordering::SeqCst)); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _server_box = black_box(client.clone()); - let _server_box = black_box(server); - let _addr_box = black_box(addr); - let _network_box = black_box(network); - }) - }); - group.bench_function("tokio_bench_concurrent_read_write_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - let mut handler = SimpleDelay::new(Duration::from_millis(100)); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - let (read, write) = network.split(handler).unwrap(); - - let read_handle = tokio::task::spawn(read.run()); - let write_handle = tokio::task::spawn(write.run()); - - let (read_res, write_res) = futures::join!(read_handle, write_handle); - assert!(read_res.is_ok()); - let (read_res, _handler) = read_res.unwrap(); - assert!(read_res.is_ok()); - assert_eq!(read_res.unwrap(), NetworkStatus::IncomingDisconnect); - - assert_eq!(write_res.unwrap().unwrap(), NetworkStatus::ShutdownSignal); - - let _network_box = black_box(network); - }) - }); -} - -fn tokio_sequential_benchmarks(c: &mut Criterion) { - let mut group = c.benchmark_group("Tokio sequential"); - group.sample_size(30); - group.measurement_time(Duration::from_secs(120)); - - group.bench_function("tokio_bench_sync_read_write", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut ()).await.unwrap(); - - let network_res = network.run(&mut ()).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); - group.bench_function("tokio_bench_sync_read_write_PingPong", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut pingpong = PingPong::new(client.clone()); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut pingpong).await.unwrap(); - - let network_res = network.run(&mut pingpong).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); - group.bench_function("tokio_bench_sync_read_write_100ms_Delay", |b| { - let runtime = tokio::runtime::Runtime::new().unwrap(); - b.to_async(runtime).iter_with_setup(tokio_setup, |(tcp_stream, server, addr)| async move { - let _server_box = black_box(server); - let _addr = black_box(addr); - - let options = ConnectOptions::new("test"); - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); - - let mut handler = SimpleDelay::new(Duration::from_millis(100)); - - let _server_box = black_box(client); - - network.connect(tcp_stream, &mut handler).await.unwrap(); - - let network_res = network.run(&mut handler).await; - - assert!(network_res.is_ok()); - let network_res = network_res.unwrap(); - assert_eq!(network_res, NetworkStatus::IncomingDisconnect); - }) - }); -} - -criterion_group!(tokio_concurrent, tokio_concurrent_benchmarks); -criterion_group!(tokio_sequential, tokio_sequential_benchmarks); From b615567ae904c5104cd4fa7eab1f6711bdd34261 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:46:29 +0100 Subject: [PATCH 28/45] Fix cargo lint in fuzzing --- fuzz/fuzz_targets/fuzz_target_1.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs index 1d3f536..9f16727 100644 --- a/fuzz/fuzz_targets/fuzz_target_1.rs +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -4,11 +4,9 @@ use libfuzzer_sys::fuzz_target; #[tokio::main(flavor = "current_thread")] async fn test(mut data: &[u8]) { - mqrstt::packets::Packet::async_read(&mut data).await; + let _ = mqrstt::packets::Packet::async_read(&mut data).await; } fuzz_target!(|data: &[u8]| { - // let mut packet = bytes::BytesMut::from(data); - // mqrstt::packets::Packet::read_from_buffer(&mut packet); test(data); }); From 53881bf2245d6427b5d0fdcc93efef7ccc8335ee Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:47:05 +0100 Subject: [PATCH 29/45] add rust-toolchain for fuzzing --- rust-toolchain.toml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 rust-toolchain.toml diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..d519a31 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +#channel = "1.82.0" \ No newline at end of file From 78def56ad49487d4b984289d6e578a71450bcb4a Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Tue, 26 Nov 2024 23:48:14 +0100 Subject: [PATCH 30/45] fuzz: expose reader functions for fuzzing --- mqrstt/src/packets/primitive/fixed_header.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mqrstt/src/packets/primitive/fixed_header.rs b/mqrstt/src/packets/primitive/fixed_header.rs index ee5a98b..4f7c626 100644 --- a/mqrstt/src/packets/primitive/fixed_header.rs +++ b/mqrstt/src/packets/primitive/fixed_header.rs @@ -30,7 +30,7 @@ pub(crate) struct FixedHeader { } impl FixedHeader { - pub fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { + pub(crate) fn read_fixed_header(mut header: Iter) -> Result<(Self, usize), ReadBytes> { if header.len() < 2 { return Err(ReadBytes::InsufficientBytes(2 - header.len())); } @@ -46,7 +46,7 @@ impl FixedHeader { Ok((Self { packet_type, flags, remaining_length }, header_length)) } - pub async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> + pub(crate) async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> where S: tokio::io::AsyncRead + Unpin, { From 782dda6877304b42206f185eadc70296ec88f3bc Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 00:00:48 +0100 Subject: [PATCH 31/45] added tokio and smol tcp examples again --- Cargo.toml | 1 + examples/tcp/Cargo.toml | 19 ++++++++++++++++++ examples/tcp/src/smol.rs | 42 +++++++++++++++++++++++++++++++++++++++ examples/tcp/src/tokio.rs | 40 +++++++++++++++++++++++++++++++++++++ 4 files changed, 102 insertions(+) create mode 100644 examples/tcp/Cargo.toml create mode 100644 examples/tcp/src/smol.rs create mode 100644 examples/tcp/src/tokio.rs diff --git a/Cargo.toml b/Cargo.toml index a62cdfe..10dc1aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,4 +3,5 @@ members = [ "mqrstt", "fuzz", + "examples/tcp" ] \ No newline at end of file diff --git a/examples/tcp/Cargo.toml b/examples/tcp/Cargo.toml new file mode 100644 index 0000000..78834ec --- /dev/null +++ b/examples/tcp/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "tcp" +version = "0.1.0" +edition = "2021" + +[dependencies] +smol = { version = "2" } + +tokio = { version = "1", features = ["full"] } + +mqrstt = { path = "../../mqrstt", features = ["logs"] } + +[[bin]] +name = "tokio" +path = "src/tokio.rs" + +[[bin]] +name = "smol" +path = "src/smol.rs" diff --git a/examples/tcp/src/smol.rs b/examples/tcp/src/smol.rs new file mode 100644 index 0000000..592d880 --- /dev/null +++ b/examples/tcp/src/smol.rs @@ -0,0 +1,42 @@ +use mqrstt::AsyncEventHandler; + +pub struct Handler { + byte_count: u64, +} + +impl AsyncEventHandler for Handler { + fn handle(&mut self, incoming_packet: mqrstt::packets::Packet) -> impl std::future::Future + Send + Sync { + async move { + if let mqrstt::packets::Packet::Publish(publish) = incoming_packet { + self.byte_count += publish.payload.len() as u64; + } + } + } +} + +fn main() { + smol::block_on(async { + let hostname = "broker.emqx.io:1883"; + + let mut handler = Handler { byte_count: 0 }; + + let stream = smol::net::TcpStream::connect(hostname).await.unwrap(); + let (mut network, client) = mqrstt::NetworkBuilder::new_from_client_id("TestClientABCDEFG").smol_network(); + + network.connect(stream, &mut handler).await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(5)).await; + + client.subscribe("testtopic/#").await.unwrap(); + + smol::spawn(async move { + network.run(&mut handler).await.unwrap(); + + dbg!(handler.byte_count); + }) + .detach(); + + smol::Timer::after(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + smol::Timer::after(std::time::Duration::from_secs(1)).await; + }); +} diff --git a/examples/tcp/src/tokio.rs b/examples/tcp/src/tokio.rs new file mode 100644 index 0000000..1e9693b --- /dev/null +++ b/examples/tcp/src/tokio.rs @@ -0,0 +1,40 @@ +use mqrstt::AsyncEventHandler; + +pub struct Handler { + byte_count: u64, +} + +impl AsyncEventHandler for Handler { + fn handle(&mut self, incoming_packet: mqrstt::packets::Packet) -> impl std::future::Future + Send + Sync { + async move { + if let mqrstt::packets::Packet::Publish(publish) = incoming_packet { + self.byte_count += publish.payload.len() as u64; + } + } + } +} + +#[tokio::main] +async fn main() { + let hostname = "broker.emqx.io:1883"; + + let mut handler = Handler { byte_count: 0 }; + + let stream = tokio::net::TcpStream::connect(hostname).await.unwrap(); + let (mut network, client) = mqrstt::NetworkBuilder::new_from_client_id("TestClientABCDEFG").tokio_network(); + + network.connect(stream, &mut handler).await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + client.subscribe("testtopic/#").await.unwrap(); + + tokio::spawn(async move { + network.run(&mut handler).await.unwrap(); + + dbg!(handler.byte_count); + }); + + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + client.disconnect().await.unwrap(); + tokio::time::sleep(std::time::Duration::from_secs(1)).await; +} From 1ce35da3c76e8505be8b15715d11a35753833b71 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 00:01:09 +0100 Subject: [PATCH 32/45] final things before 0.4.0 --- mqrstt/src/client.rs | 84 ++++++------------------- mqrstt/src/lib.rs | 44 ++----------- mqrstt/src/state_handler.rs | 12 ++-- mqrstt/src/tokio/network.rs | 119 +++++++++++++++++++++++++++++++----- mqrstt/src/tokio/stream.rs | 81 +++++++++++++----------- 5 files changed, 179 insertions(+), 161 deletions(-) diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 0a96e44..50bcc8a 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -46,56 +46,6 @@ impl MqttClient { max_packet_size, } } - - /// This function is only here for you to use during testing of for example your handler - /// For a simple client look at [`MqttClient::test_client`] - #[cfg(feature = "test")] - pub fn test_custom_client(available_packet_ids_r: Receiver, to_network_s: Sender, max_packet_size: usize) -> Self { - Self { - available_packet_ids_r, - to_network_s, - max_packet_size, - } - } - - /// This function is only here for you to use during testing of for example your handler - /// For control over the input of this type look at [`MqttClient::test_custom_client`] - /// - /// The returned values should not be dropped otherwise the client won't be able to operate normally. - /// - /// # Example - /// ```ignore - /// let ( - /// client, // An instance of this client - /// ids, // Allows you to indicate which packet IDs have become available again. - /// network_receiver // Messages send through the `client` will be dispatched through this channel - /// ) = MqttClient::test_client(); - /// - /// // perform testing - /// - /// // Make sure to not drop these before the test is done! - /// std::hint::black_box((ids, network_receiver)); - /// ``` - #[cfg(feature = "test")] - pub fn test_client() -> (Self, crate::available_packet_ids::AvailablePacketIds, Receiver) { - use async_channel::unbounded; - - use crate::{available_packet_ids::AvailablePacketIds, util::constants::MAXIMUM_PACKET_SIZE}; - - let (available_packet_ids, available_packet_ids_r) = AvailablePacketIds::new(u16::MAX); - - let (s, r) = unbounded(); - - ( - Self { - available_packet_ids_r, - to_network_s: s, - max_packet_size: MAXIMUM_PACKET_SIZE as usize, - }, - available_packet_ids, - r, - ) - } } /// Async functions to perform MQTT operations @@ -106,7 +56,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -140,6 +90,8 @@ impl MqttClient { sub.validate(self.max_packet_size)?; self.to_network_s.send(Packet::Subscribe(sub)).await.map_err(|_| ClientError::NoNetworkChannel)?; + #[cfg(feature = "logs")] + info!("Send to network: Subscribe with ID {:?}", pkid); Ok(()) } @@ -150,7 +102,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("example_id").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -210,7 +162,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -264,7 +216,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -334,7 +286,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -381,7 +333,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (_, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; @@ -450,7 +402,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect().await.unwrap(); @@ -476,7 +428,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; @@ -512,7 +464,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscriptionOptions, RetainHandling}; @@ -556,7 +508,7 @@ impl MqttClient { /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// use mqrstt::packets::QoS; /// use mqrstt::packets::{SubscribeProperties, SubscriptionOptions, RetainHandling}; @@ -616,7 +568,7 @@ impl MqttClient { /// This function blocks until the packet is queued for transmission /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -672,7 +624,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::QoS; @@ -742,7 +694,7 @@ impl MqttClient { /// /// # Examples /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// // Unsubscribe from a single topic specified as a string: @@ -790,7 +742,7 @@ impl MqttClient { /// # Examples /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::UnsubscribeProperties; @@ -851,7 +803,7 @@ impl MqttClient { /// # Example /// /// ``` - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// mqtt_client.disconnect_blocking().unwrap(); @@ -877,7 +829,7 @@ impl MqttClient { /// /// ``` /// - /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_sequential_network(); + /// # let (network, mqtt_client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream>::new_from_client_id("Example").smol_network(); /// # smol::block_on(async { /// /// use mqrstt::packets::DisconnectProperties; diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index ccd133c..9cfd25d 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -44,7 +44,7 @@ //! // To reconnect after a disconnect or error //! let (mut network, client) = NetworkBuilder //! ::new_from_client_id("mqrsttSmolExample") -//! .smol_sequential_network(); +//! .smol_network(); //! let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) //! .await //! .unwrap(); @@ -236,9 +236,9 @@ where /// ``` /// let (mut network, client) = mqrstt::NetworkBuilder::<(), smol::net::TcpStream> /// ::new_from_client_id("ExampleClient") - /// .smol_sequential_network(); + /// .smol_network(); /// ``` - pub fn smol_sequential_network(self) -> (smol::Network, MqttClient) { + pub fn smol_network(self) -> (smol::Network, MqttClient) { let (to_network_s, to_network_r) = async_channel::bounded(CHANNEL_SIZE); let (apkids, apkids_r) = available_packet_ids::AvailablePacketIds::new(self.options.send_maximum()); @@ -253,38 +253,6 @@ where } } -#[cfg(feature = "todo")] -/// Creates a new [`sync::Network`] and [`MqttClient`] that can be connected to a broker. -/// S should implement [`std::io::Read`] and [`std::io::Write`]. -/// Additionally, S should be made non_blocking otherwise it will not progress. -/// -/// # Example -/// -/// ``` -/// use mqrstt::ConnectOptions; -/// -/// let options = ConnectOptions::new("ExampleClient"); -/// let (network, client) = mqrstt::new_sync::(options); -/// ``` -pub fn new_sync(options: ConnectOptions) -> (sync::Network, MqttClient) -where - S: std::io::Read + std::io::Write + Sized + Unpin, -{ - use available_packet_ids::AvailablePacketIds; - - let (to_network_s, to_network_r) = async_channel::bounded(100); - - let (apkids, apkids_r) = AvailablePacketIds::new(options.send_maximum()); - - let max_packet_size = options.maximum_packet_size(); - - let client = MqttClient::new(apkids_r, to_network_s, max_packet_size); - - let network = sync::Network::new(options, to_network_r, apkids); - - (network, client) -} - #[cfg(test)] fn random_chars() -> String { rand::Rng::sample_iter(rand::thread_rng(), &rand::distributions::Alphanumeric).take(7).map(char::from).collect() @@ -310,7 +278,7 @@ mod smol_lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingpong = PingPong::new(client.clone()); @@ -347,7 +315,7 @@ mod smol_lib_test { let address = "broker.emqx.io"; let port = 1883; - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); @@ -393,7 +361,7 @@ mod smol_lib_test { let (n, _) = futures::join!( async { - let (mut network, client) = NetworkBuilder::new_from_options(options).smol_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).smol_network(); let stream = smol::net::TcpStream::connect((address, port)).await.unwrap(); let mut pingresp = crate::example_handlers::PingResp::new(client.clone()); network.connect(stream, &mut pingresp).await diff --git a/mqrstt/src/state_handler.rs b/mqrstt/src/state_handler.rs index d0eeeb9..ecf7fd6 100644 --- a/mqrstt/src/state_handler.rs +++ b/mqrstt/src/state_handler.rs @@ -1,7 +1,6 @@ use crate::available_packet_ids::AvailablePacketIds; use crate::connect_options::ConnectOptions; use crate::error::HandlerError; -use crate::packets::{PubRecReasonCode, PubAckReasonCode, ConnAckReasonCode}; use crate::packets::PubComp; use crate::packets::PubRec; use crate::packets::PubRel; @@ -12,6 +11,7 @@ use crate::packets::Subscribe; use crate::packets::UnsubAck; use crate::packets::Unsubscribe; use crate::packets::{ConnAck, Disconnect}; +use crate::packets::{ConnAckReasonCode, PubAckReasonCode, PubRecReasonCode}; use crate::packets::{Packet, PacketType}; use crate::packets::{PubAck, PubAckProperties}; use crate::state::State; @@ -188,6 +188,7 @@ impl StateHandler { _a => { #[cfg(test)] unreachable!("Was given unexpected packet {:?} ", _a); + #[cfg(not(test))] Ok(()) } } @@ -247,13 +248,8 @@ mod handler_tests { use crate::{ available_packet_ids::AvailablePacketIds, packets::{ - Packet, - PubComp, PubCompProperties, PubCompReasonCode, - PubRec, PubRecProperties, PubRecReasonCode, - PubRel, PubRelProperties, PubRelReasonCode, - QoS, - SubAck, SubAckProperties, SubAckReasonCode, - UnsubAck, UnsubAckProperties, UnsubAckReasonCode + Packet, PubComp, PubCompProperties, PubCompReasonCode, PubRec, PubRecProperties, PubRecReasonCode, PubRel, PubRelProperties, PubRelReasonCode, QoS, SubAck, SubAckProperties, + SubAckReasonCode, UnsubAck, UnsubAckProperties, UnsubAckReasonCode, }, tests::test_packets::{create_connack_packet, create_puback_packet, create_publish_packet, create_subscribe_packet, create_unsubscribe_packet}, ConnectOptions, StateHandler, diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 84efd01..3caa0e4 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -13,7 +13,7 @@ use crate::packets::{Disconnect, Packet, PacketType}; use crate::{AsyncEventHandler, NetworkStatus, StateHandler}; -use super::stream::Stream; +use super::stream::StreamExt; /// [`Network`] reads and writes to the network based on tokios [`::tokio::io::AsyncReadExt`] [`::tokio::io::AsyncWriteExt`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. @@ -21,7 +21,7 @@ use super::stream::Stream; /// (i.e. you need to reconnect after any expected or unexpected disconnect). pub struct Network { handler: PhantomData, - network: Option>, + network: Option, /// Options of the current mqtt connection options: ConnectOptions, @@ -55,8 +55,8 @@ where S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream - pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { - let (mut network, conn_ack) = Stream::connect(&self.options, stream).await?; + pub async fn connect(&mut self, mut stream: S, handler: &mut H) -> Result<(), ConnectionError> { + let conn_ack = stream.connect(&self.options).await?; self.last_network_action = Instant::now(); if let Some(keep_alive_interval) = conn_ack.connack_properties.server_keep_alive { @@ -68,12 +68,12 @@ where let packets = self.state_handler.handle_incoming_connack(&conn_ack)?; handler.handle(Packet::ConnAck(conn_ack)).await; - if let Some(mut packets) = packets { - network.write_all(&mut packets).await?; + if let Some(packets) = packets { + stream.write_packets(&packets).await?; self.last_network_action = Instant::now(); } - self.network = Some(network); + self.network = Some(stream); Ok(()) } @@ -117,7 +117,6 @@ where } = self; let mut await_pingresp = None; - // let mut outgoing_packet_buffer = Vec::new(); loop { let sleep; @@ -129,7 +128,10 @@ where if let Some(stream) = network { tokio::select! { - res = stream.read() => { + res = stream.read_packet() => { + #[cfg(feature = "logs")] + tracing::trace!("Received incoming packet {:?}", &res); + let packet = res?; match packet{ Packet::PingResp => { @@ -145,12 +147,12 @@ where (maybe_reply_packet, true) => { handler.handle(packet).await; if let Some(reply_packet) = maybe_reply_packet { - stream.write(&reply_packet).await?; + stream.write_packet(&reply_packet).await?; *last_network_action = Instant::now(); } }, (Some(reply_packet), false) => { - stream.write(&reply_packet).await?; + stream.write_packet(&reply_packet).await?; *last_network_action = Instant::now(); }, (None, false) => (), @@ -159,8 +161,15 @@ where } }, outgoing = to_network_r.recv() => { + #[cfg(feature = "logs")] + tracing::trace!("Received outgoing item {:?}", &outgoing); + let packet = outgoing?; - stream.write(&packet).await?; + + #[cfg(feature = "logs")] + tracing::trace!("Sending packet {}", packet); + + stream.write_packet(&packet).await?; let disconnect = packet.packet_type() == PacketType::Disconnect; state_handler.handle_outgoing_packet(packet)?; @@ -173,13 +182,13 @@ where }, _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { let packet = Packet::PingReq; - stream.write(&packet).await?; + stream.write_packet(&packet).await?; *last_network_action = Instant::now(); await_pingresp = Some(Instant::now()); }, _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; - stream.write(&Packet::Disconnect(disconnect)).await?; + stream.write_packet(&Packet::Disconnect(disconnect)).await?; return Ok(NetworkStatus::KeepAliveTimeout); } } @@ -188,4 +197,86 @@ where } } } + + // async fn concurrent_tokio_select(&mut self, handler: &mut H) -> Result { + // let Network { + // network, + // options, + // last_network_action, + // perform_keep_alive, + // to_network_r, + // handler: _, + // state_handler, + // } = self; + + // let mut await_pingresp = None; + + // loop { + // let sleep; + // if let Some(instant) = await_pingresp { + // sleep = instant + options.get_keep_alive_interval() - Instant::now(); + // } else { + // sleep = *last_network_action + options.get_keep_alive_interval() - Instant::now(); + // } + + // if let Some(stream) = network { + // tokio::select! { + // res = stream.read_packet() => { + // let packet = res?; + // match packet{ + // Packet::PingResp => { + // handler.handle(packet).await; + // await_pingresp = None; + // }, + // Packet::Disconnect(_) => { + // handler.handle(packet).await; + // return Ok(NetworkStatus::IncomingDisconnect); + // } + // packet => { + // match state_handler.handle_incoming_packet(&packet)? { + // (maybe_reply_packet, true) => { + // handler.handle(packet).await; + // if let Some(reply_packet) = maybe_reply_packet { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // } + // }, + // (Some(reply_packet), false) => { + // stream.write_packet(&reply_packet).await?; + // *last_network_action = Instant::now(); + // }, + // (None, false) => (), + // } + // } + // } + // }, + // outgoing = to_network_r.recv() => { + // let packet = outgoing?; + // stream.write_packet(&packet).await?; + // let disconnect = packet.packet_type() == PacketType::Disconnect; + + // state_handler.handle_outgoing_packet(packet)?; + // *last_network_action = Instant::now(); + + // if disconnect{ + // return Ok(NetworkStatus::OutgoingDisconnect); + // } + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_none() && *perform_keep_alive => { + // let packet = Packet::PingReq; + // stream.write_packet(&packet).await?; + // *last_network_action = Instant::now(); + // await_pingresp = Some(Instant::now()); + // }, + // _ = tokio::time::sleep(sleep), if await_pingresp.is_some() => { + // let disconnect = Disconnect{ reason_code: DisconnectReasonCode::KeepAliveTimeout, properties: Default::default() }; + // stream.write_packet(&Packet::Disconnect(disconnect)).await?; + // return Ok(NetworkStatus::KeepAliveTimeout); + // } + // } + // } else { + // return Err(ConnectionError::NoNetwork); + // } + // } + // } } diff --git a/mqrstt/src/tokio/stream.rs b/mqrstt/src/tokio/stream.rs index 081d720..a5647e7 100644 --- a/mqrstt/src/tokio/stream.rs +++ b/mqrstt/src/tokio/stream.rs @@ -7,65 +7,76 @@ use crate::packets::ConnAck; use crate::packets::{ConnAckReasonCode, Packet}; use crate::{connect_options::ConnectOptions, error::ConnectionError}; -#[derive(Debug)] -pub struct Stream { - stream: S, +pub(crate) trait StreamExt { + fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future>; + fn read_packet(&mut self) -> impl std::future::Future>; + fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future>; + fn write_packets(&mut self, packets: &[Packet]) -> impl std::future::Future>; + fn flush_packets(&mut self) -> impl std::future::Future>; } -impl Stream +impl StreamExt for S where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { - pub async fn connect(options: &ConnectOptions, stream: S) -> Result<(Self, ConnAck), ConnectionError> { - let mut s = Self { stream }; + fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future> { + async move { + let connect = options.create_connect_from_options(); - let connect = options.create_connect_from_options(); + self.write_packet(&connect).await?; - s.write(&connect).await?; - - let packet = Packet::async_read(&mut s.stream).await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - #[cfg(feature = "logs")] - trace!("Connected to server"); - Ok((s, con)) + let packet = Packet::async_read(self).await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); + Ok(con) + } else { + Err(ConnectionError::ConnectionRefused(con.reason_code)) + } } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) + Err(ConnectionError::NotConnAck(packet)) } - } else { - Err(ConnectionError::NotConnAck(packet)) } } - pub async fn read(&mut self) -> Result { - Ok(Packet::async_read(&mut self.stream).await?) + fn read_packet(&mut self) -> impl std::future::Future> { + async move { Ok(Packet::async_read(self).await?) } } - pub async fn write(&mut self, packet: &Packet) -> Result<(), ConnectionError> { - match packet.async_write(&mut self.stream).await { - Ok(_) => (), - Err(err) => { - return match err { - crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), - crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future> { + async move { + match packet.async_write(self).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), + } } } - } - self.stream.flush().await?; - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); + + self.flush().await?; + // self.flush_packets().await?; - Ok(()) + Ok(()) + } } - pub async fn write_all(&mut self, packets: &mut Vec) -> Result<(), ConnectionError> { + async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> { for packet in packets { - let _ = packet.async_write(&mut self.stream).await; + let _ = packet.async_write(self).await; #[cfg(feature = "logs")] trace!("Sending packet {}", packet); } - self.stream.flush().await?; + self.flush_packets().await?; Ok(()) } + + fn flush_packets(&mut self) -> impl std::future::Future> { + tokio::io::AsyncWriteExt::flush(self) + } } From 1c914c8777aba6e59ccfa795b06ded27dce2e207 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 00:42:24 +0100 Subject: [PATCH 33/45] fix up test cases --- mqrstt/src/client.rs | 19 +++------ mqrstt/src/event_handlers.rs | 11 ++---- mqrstt/src/lib.rs | 74 ++++++++++++++++++------------------ mqrstt/src/packets/mod.rs | 21 +++++++--- 4 files changed, 60 insertions(+), 65 deletions(-) diff --git a/mqrstt/src/client.rs b/mqrstt/src/client.rs index 50bcc8a..1e7c72d 100644 --- a/mqrstt/src/client.rs +++ b/mqrstt/src/client.rs @@ -7,16 +7,7 @@ use crate::{ error::ClientError, packets::{ mqtt_trait::PacketValidation, - DisconnectReasonCode, - Packet, - QoS, - // disconnect::{Disconnect, DisconnectProperties}, - // publish::{Publish, PublishProperties}, - // subscribe::{Subscribe, SubscribeProperties, Subscription}, - // unsubscribe::{Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, - {Disconnect, DisconnectProperties}, - {Publish, PublishProperties}, - {Subscribe, SubscribeProperties, SubscribeTopics}, + DisconnectReasonCode, Packet, QoS, {Disconnect, DisconnectProperties}, {Publish, PublishProperties}, {Subscribe, SubscribeProperties, SubscribeTopics}, {Unsubscribe, UnsubscribeProperties, UnsubscribeTopics}, }, }; @@ -26,7 +17,7 @@ use crate::{ /// /// This object can be obtained by calling the builder functions on [`crate::NetworkBuilder`] /// -/// This client should be used in combindation with a handler [`crate::AsyncEventHandler`] or [`crate::AsyncEventHandlerMut`] to handle incoming messages. +/// This client should be used in combination with a handler [`crate::AsyncEventHandler`] to receive and send messages. pub struct MqttClient { /// Provides this client with an available packet id or waits on it. available_packet_ids_r: Receiver, @@ -212,7 +203,7 @@ impl MqttClient { /// Creates a Publish packet with additional publish properties. /// The packet is then asynchronously transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// # Examples /// ``` @@ -563,7 +554,7 @@ impl MqttClient { /// Creates a Publish packet which is then transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// This function blocks until the packet is queued for transmission /// # Examples @@ -618,7 +609,7 @@ impl MqttClient { /// Creates a Publish packet with additional publish properties. /// The packet is then transferred to the Network stack for transmission. /// - /// Can be called with any payload that can be converted into [`Bytes`] + /// Can be called with any payload that can be converted into [`Vec`] /// /// This function blocks until the packet is queued for transmission /// diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 2b786fe..555d04f 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -50,15 +50,12 @@ pub mod example_handlers { pub struct PingResp { pub client: MqttClient, - pub ping_resp_received: AtomicU16, + pub ping_resp_received: u32, } impl PingResp { pub fn new(client: MqttClient) -> Self { - Self { - client, - ping_resp_received: AtomicU16::new(0), - } + Self { client, ping_resp_received: 0 } } } @@ -66,7 +63,7 @@ pub mod example_handlers { async fn handle(&mut self, event: packets::Packet) -> () { use Packet::*; if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.ping_resp_received += 1; } println!("Received packet: {}", event); } @@ -76,7 +73,7 @@ pub mod example_handlers { fn handle(&mut self, event: Packet) { use Packet::*; if event == PingResp { - self.ping_resp_received.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + self.ping_resp_received += 1; } println!("Received packet: {}", event); } diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index 9cfd25d..c131ec6 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -19,7 +19,6 @@ //! //! Notes: //! ---------------------------- -//! - While the handler is processing a message the stream blocks. To prevent this, spawn a task in the handler or use [tokio::ConcurrentHandler]. //! - Handlers only get incoming packets //! - Create a new connection when an error or disconnect is encountered //! @@ -83,7 +82,7 @@ //! async fn main() { //! let (mut network, client) = NetworkBuilder //! ::new_from_client_id("TokioTcpPingPongExample") -//! .tokio_sequential_network(); +//! .tokio_network(); //! //! // Construct a no op handler //! let mut nop = NOP{}; @@ -206,7 +205,7 @@ where /// let options = ConnectOptions::new("ExampleClient"); /// let (mut network, client) = mqrstt::NetworkBuilder::<(), tokio::net::TcpStream> /// ::new_from_options(options) - /// .tokio_sequential_network(); + /// .tokio_network(); /// ``` pub fn tokio_network(self) -> (tokio::Network, MqttClient) where @@ -266,12 +265,12 @@ mod smol_lib_test { use rand::Rng; - use crate::{example_handlers::PingPong, packets::QoS, ConnectOptions, NetworkBuilder}; + use crate::{example_handlers::PingPong, packets::QoS, random_chars, ConnectOptions, NetworkBuilder}; #[test] fn test_smol_tcp() { smol::block_on(async { - let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); + let mut client_id: String = random_chars(); client_id += "_SmolTcpPingPong"; let options = ConnectOptions::new(client_id); @@ -339,7 +338,7 @@ mod smol_lib_test { ); assert!(n.is_ok()); let pingresp = n.unwrap(); - assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); + assert_eq!(2, pingresp.ping_resp_received); }); } @@ -385,45 +384,45 @@ mod smol_lib_test { #[cfg(feature = "tokio")] #[cfg(test)] mod tokio_lib_test { - use crate::example_handlers::PingPong; - - use crate::packets::QoS; + use crate::example_handlers::PingResp; + use crate::random_chars; + use crate::ConnectOptions; - use std::{sync::Arc, time::Duration}; - - // #[tokio::test] - // async fn test_tokio_ping_req() { - // let mut client_id: String = rand::thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(7).map(char::from).collect(); - // client_id += "_TokioTcppingrespTest"; - // let mut options = ConnectOptions::new(client_id); - // let keep_alive_interval = 5; - // options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); + use std::time::Duration; - // let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; + #[tokio::test] + async fn test_tokio_ping_req() { + let mut client_id: String = random_chars(); + client_id += "_TokioTcppingrespTest"; + let mut options = ConnectOptions::new(client_id); + let keep_alive_interval = 5; + options.set_keep_alive_interval(Duration::from_secs(keep_alive_interval)); - // let (mut network, client) = new_tokio(options); + let wait_duration = options.get_keep_alive_interval() * 2 + options.get_keep_alive_interval() / 2; - // let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let (mut network, client) = crate::NetworkBuilder::new_from_options(options).tokio_network(); - // let pingresp = Arc::new(crate::test_handlers::PingResp::new(client.clone())); + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - // network.connect(stream, &mut pingresp).await.unwrap(); + let mut pingresp = PingResp::new(client.clone()); - // let (read, write) = network.split(pingresp.clone()).unwrap(); + network.connect(stream, &mut pingresp).await.unwrap(); - // let read_handle = tokio::task::spawn(read.run()); - // let write_handle = tokio::task::spawn(write.run()); + let network_handle = tokio::task::spawn(async move { + network.run(&mut pingresp).await; + pingresp + }); - // tokio::time::sleep(wait_duration).await; - // client.disconnect().await.unwrap(); + tokio::time::sleep(wait_duration).await; + client.disconnect().await.unwrap(); - // tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_secs(1)).await; - // let (read_result, write_result) = tokio::join!(read_handle, write_handle); - // let (read_result, write_result) = (read_result.unwrap(), write_result.unwrap()); - // assert!(write_result.is_ok()); - // assert_eq!(2, pingresp.ping_resp_received.load(std::sync::atomic::Ordering::Acquire)); - // } + let result = network_handle.await; + assert!(result.is_ok()); + let result = result.unwrap(); + assert_eq!(2, result.ping_resp_received); + } #[cfg(all(feature = "tokio", target_family = "windows"))] #[tokio::test] @@ -435,11 +434,11 @@ mod tokio_lib_test { let address = ("127.0.0.1", 2000); let client_id: String = crate::random_chars() + "_TokioTcppingrespTest"; - let options = ConnectOptions::new(client_id); + let options = crate::ConnectOptions::new(client_id); let (n, _) = tokio::join!( async move { - let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_sequential_network(); + let (mut network, client) = NetworkBuilder::new_from_options(options).tokio_network(); let stream = tokio::net::TcpStream::connect(address).await.unwrap(); @@ -456,8 +455,7 @@ mod tokio_lib_test { ); if let ConnectionError::Io(err) = n.unwrap_err() { - assert_eq!(ErrorKind::ConnectionReset, err.kind()); - assert_eq!("Connection reset by peer".to_string(), err.to_string()); + assert_eq!(ErrorKind::UnexpectedEof, err.kind()); } else { panic!(); } diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 3846052..73a477f 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -451,14 +451,23 @@ mod tests { use crate::packets::Packet; - use crate::tests::test_packets::{create_empty_publish_packet, disconnect_case, ping_req_case, ping_resp_case, publish_case, pubrel_case, pubrel_smallest_case}; + use crate::tests::test_packets::*; #[rstest::rstest] - // #[case(disconnect_case())] - // #[case(ping_req_case())] - // #[case(ping_resp_case())] - // #[case(publish_case())] - // #[case(pubrel_case())] + #[case(ping_req_case().1)] + #[case(ping_resp_case().1)] + #[case(connack_case().1)] + #[case(create_subscribe_packet(1))] + #[case(create_subscribe_packet(65335))] + #[case(create_puback_packet(1))] + #[case(create_puback_packet(65335))] + #[case(create_disconnect_packet())] + #[case(create_connack_packet(true))] + #[case(create_connack_packet(false))] + #[case(publish_packet_1())] + #[case(publish_packet_2())] + #[case(publish_packet_3())] + #[case(publish_packet_4())] #[case(create_empty_publish_packet())] fn test_write_read_write_read_cases(#[case] packet: Packet) { let mut buffer = BytesMut::new(); From 4fc49246ad6eab9e4e44b906b76be17ba420c8d9 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 00:43:57 +0100 Subject: [PATCH 34/45] bump version number to v0.4.0 --- mqrstt/Cargo.toml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index 705ba9a..31d00c4 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mqrstt" -version = "0.3.0" +version = "0.4.0" homepage = "https://github.com/GunnarMorrigan/mqrstt" repository = "https://github.com/GunnarMorrigan/mqrstt" documentation = "https://docs.rs/mqrstt" @@ -10,7 +10,6 @@ edition = "2021" license = "MPL-2.0" keywords = ["MQTT", "IoT", "MQTTv5", "messaging", "client"] description = "Pure rust MQTTv5 client implementation Smol and Tokio" - rust-version = "1.75" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -53,8 +52,6 @@ tokio = { version = "1", features = [ smol = { version = "2", optional = true } [dev-dependencies] -criterion = { version = "0.5.1", features = ["async_tokio"] } - tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } smol = { version = "2.0.0" } From 1040da6708b65e3e153b56e5f703be29df7067ee Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:22:36 +0100 Subject: [PATCH 35/45] linter --- .github/workflows/rust.yml | 2 +- README.md | 6 +- mqrstt/src/error.rs | 10 +- mqrstt/src/event_handlers.rs | 8 +- mqrstt/src/lib.rs | 7 +- mqrstt/src/packets/auth/mod.rs | 11 +- mqrstt/src/packets/connack/mod.rs | 74 +++++----- mqrstt/src/packets/connect/connect_flags.rs | 20 ++- mqrstt/src/packets/connect/last_will.rs | 12 +- .../src/packets/macros/properties_macros.rs | 40 +----- .../src/packets/macros/reason_code_macros.rs | 2 +- mqrstt/src/packets/mod.rs | 2 +- mqrstt/src/packets/mqtt_trait/mod.rs | 2 +- .../src/packets/mqtt_trait/primitive_impl.rs | 98 ++++++------- mqrstt/src/packets/primitive/property_type.rs | 10 +- mqrstt/src/packets/primitive/qos.rs | 10 +- .../src/packets/primitive/variable_integer.rs | 132 ++++++++---------- mqrstt/src/packets/pubcomp/mod.rs | 94 ++++++------- mqrstt/src/packets/publish/mod.rs | 76 +++++----- mqrstt/src/packets/pubrec/mod.rs | 48 +++---- mqrstt/src/packets/suback/mod.rs | 66 ++++----- mqrstt/src/packets/subscribe/mod.rs | 104 +++++++------- mqrstt/src/packets/unsuback/mod.rs | 46 +++--- mqrstt/src/packets/unsubscribe/mod.rs | 73 +++++----- mqrstt/src/state.rs | 6 +- mqrstt/src/tokio/network.rs | 9 +- mqrstt/src/tokio/stream.rs | 58 ++++---- 27 files changed, 463 insertions(+), 563 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4ce4a31..67e83c0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,7 +34,7 @@ jobs: # run clippy to verify we have no warnings - run: cargo fetch - name: cargo clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy -p mqrstt test: name: Test diff --git a/README.md b/README.md index 58dff00..7e112dc 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ pub struct PingPong { } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -132,7 +132,7 @@ pub struct PingPong { } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -212,7 +212,7 @@ pub struct PingPong { impl EventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - fn handle(&mut self, event: packets::Packet) -> () { + fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index 7651ceb..b439912 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -69,8 +69,8 @@ pub enum HandlerError { #[error("The incoming channel between network and handler is closed")] IncomingNetworkChannelClosed, - #[error("The outgoing channel between handler and network is closed: {0}")] - OutgoingNetworkChannelClosed(#[from] SendError), + #[error("The outgoing channel between handler and network is closed")] + OutgoingNetworkChannelClosed, #[error("Channel between client and handler closed")] ClientChannelClosed, @@ -88,6 +88,12 @@ pub enum HandlerError { UnexpectedPacket(PacketType), } +impl From> for HandlerError { + fn from(_: SendError) -> Self { + HandlerError::OutgoingNetworkChannelClosed + } +} + /// Errors producable by the [`crate::MqttClient`] #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum ClientError { diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 555d04f..b69fca6 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -13,9 +13,7 @@ pub trait AsyncEventHandler { /// This is a simple no operation handler. impl AsyncEventHandler for () { - fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { - async {} - } + async fn handle(&mut self, _: Packet) {} } pub trait EventHandler { @@ -60,7 +58,7 @@ pub mod example_handlers { } impl AsyncEventHandler for PingResp { - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { use Packet::*; if event == PingResp { self.ping_resp_received += 1; @@ -91,7 +89,7 @@ pub mod example_handlers { } impl AsyncEventHandler for PingPong { - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index c131ec6..e4d0477 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -123,7 +123,7 @@ pub mod smol; /// Contains the reader and writer parts for the tokio runtime. /// /// Module [`crate::tokio`] contains both a synchronized and concurrent approach to call the users `Handler`. -#[cfg(any(feature = "tokio"))] +#[cfg(feature = "tokio")] pub mod tokio; /// Error types that the user can see during operation of the client. @@ -342,7 +342,7 @@ mod smol_lib_test { }); } - #[cfg(all(target_family = "windows"))] + #[cfg(target_family = "windows")] #[test] fn test_close_write_tcp_stream_smol() { use crate::error::ConnectionError; @@ -409,7 +409,8 @@ mod tokio_lib_test { network.connect(stream, &mut pingresp).await.unwrap(); let network_handle = tokio::task::spawn(async move { - network.run(&mut pingresp).await; + let result = network.run(&mut pingresp).await; + // check result and or restart the connection pingresp }); diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index 74a2107..e16fdda 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -1,5 +1,4 @@ mod properties; -use std::future::Future; pub use properties::AuthProperties; mod reason_code; @@ -48,12 +47,10 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Auth where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl Future> { - async move { - let reason_code_writen = self.reason_code.async_write(stream).await?; - let properties_writen = self.properties.async_write(stream).await?; - Ok(reason_code_writen + properties_writen) - } + async fn async_write(&self, stream: &mut S) -> Result { + let reason_code_writen = self.reason_code.async_write(stream).await?; + let properties_writen = self.properties.async_write(stream).await?; + Ok(reason_code_writen + properties_writen) } } diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 2ff8650..94da71f 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -53,21 +53,19 @@ impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, _: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; - let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; - let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; - - Ok(( - Self { - connack_flags, - reason_code, - connack_properties, - }, - read_bytes + reason_code_read_bytes + connack_properties_read_bytes, - )) - } + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; + let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; + let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; + + Ok(( + Self { + connack_flags, + reason_code, + connack_properties, + }, + read_bytes + reason_code_read_bytes + connack_properties_read_bytes, + )) } } @@ -85,15 +83,13 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for ConnAck where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use crate::packets::mqtt_trait::MqttAsyncWrite; - let connack_flags_writen = self.connack_flags.async_write(stream).await?; - let reason_code_writen = self.reason_code.async_write(stream).await?; - let connack_properties_writen = self.connack_properties.async_write(stream).await?; - - Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) - } + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + let connack_flags_writen = self.connack_flags.async_write(stream).await?; + let reason_code_writen = self.reason_code.async_write(stream).await?; + let connack_properties_writen = self.connack_properties.async_write(stream).await?; + + Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) } } @@ -114,16 +110,14 @@ impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; - Ok(( - Self { - session_present: (byte & 0b00000001) == 0b00000001, - }, - 1, - )) - } + async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let byte = stream.read_u8().await?; + Ok(( + Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, + 1, + )) } } @@ -154,14 +148,12 @@ impl crate::packets::mqtt_trait::MqttAsyncWrite for ConnAckFlags where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use tokio::io::AsyncWriteExt; - let byte = self.session_present as u8; + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.session_present as u8; - stream.write_u8(byte).await?; - Ok(1) - } + stream.write_u8(byte).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs index c2a576e..129a132 100644 --- a/mqrstt/src/packets/connect/connect_flags.rs +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -77,11 +77,9 @@ impl MqttAsyncRead for ConnectFlags where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; - Ok((ConnectFlags::from_u8(byte)?, 1)) - } + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; + Ok((ConnectFlags::from_u8(byte)?, 1)) } } @@ -96,13 +94,11 @@ impl MqttAsyncWrite for ConnectFlags where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use tokio::io::AsyncWriteExt; - let byte = self.into_u8()?; - stream.write_u8(byte).await?; + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.into_u8()?; + stream.write_u8(byte).await?; - Ok(1) - } + Ok(1) } } diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index b93e88e..bc076bc 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -82,14 +82,12 @@ impl MqttAsyncWrite for LastWill where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let properties_written = self.last_will_properties.async_write(stream).await?; - let topic_written = self.topic.async_write(stream).await?; - let payload_written = self.payload.async_write(stream).await?; + async fn async_write(&self, stream: &mut S) -> Result { + let properties_written = self.last_will_properties.async_write(stream).await?; + let topic_written = self.topic.async_write(stream).await?; + let payload_written = self.payload.async_write(stream).await?; - Ok(properties_written + topic_written + payload_written) - } + Ok(properties_written + topic_written + payload_written) } } diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index d845d9f..b71e149 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -34,40 +34,14 @@ macro_rules! define_properties { } impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let mut bytes_writen = 0; - $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; - $( - $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); - )* - - Ok(bytes_writen) - } - - // let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; - // if len == 0 { - // return Ok((Self::default(), length_variable_integer)); - // } - - // let mut properties = $name::default(); - - // let mut read_property_bytes = 0; - // loop { - // let (prop, read_bytes) = crate::packets::PropertyType::async_read(stream).await?; - // read_property_bytes += read_bytes; - // match prop { - // $( - // $crate::packets::macros::properties_read_match_branch_name!($prop_variant) => $crate::packets::macros::properties_read_match_branch_body!(stream, properties, read_property_bytes, PropertyType::$prop_variant), - // )* - // e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), - // } - // if read_property_bytes == len { - // break; - // } - // } + async fn async_write(&self, stream: &mut S) -> Result { + let mut bytes_writen = 0; + $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; + $( + $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); + )* - // Ok((properties, length_variable_integer + read_property_bytes)) + Ok(bytes_writen) } } diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs index 8e1fa29..68ec2e4 100644 --- a/mqrstt/src/packets/macros/reason_code_macros.rs +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -14,7 +14,7 @@ macro_rules! reason_code { } -> ()) } - pub(crate) fn to_u8(&self) -> u8 { + pub(crate) fn to_u8(self) -> u8 { $crate::packets::macros::reason_code_match_write!(@ $name, self, { $($code,)* } -> ()) diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 73a477f..2eb135a 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -347,7 +347,7 @@ impl Packet { #[cfg(feature = "logs")] tracing::trace!("Read packet header: {:?}", header); - Ok(Packet::async_read_packet(header, stream).await?) + Packet::async_read_packet(header, stream).await } pub fn read(buffer: &mut BytesMut) -> Result> { diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index 50ce760..f666753 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -43,7 +43,7 @@ pub trait MqttWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; } -impl<'a, T> MqttWrite for &'a T +impl MqttWrite for &T where T: MqttWrite, { diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs index 8ff217a..b0612cb 100644 --- a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -43,13 +43,11 @@ impl MqttAsyncWrite for Box where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -73,13 +71,11 @@ impl MqttAsyncWrite for &str where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -131,13 +127,11 @@ impl MqttAsyncWrite for String where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -189,13 +183,11 @@ impl MqttAsyncWrite for Bytes where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_ref()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_ref()).await?; + Ok(2 + self.len()) } } @@ -234,13 +226,11 @@ impl MqttAsyncWrite for Vec where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self).await?; + Ok(2 + self.len()) } } impl WireLength for Vec { @@ -304,15 +294,13 @@ impl MqttAsyncWrite for bool where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - if *self { - stream.write_all(&[1]).await?; - } else { - stream.write_all(&[0]).await?; - } - Ok(1) + async fn async_write(&self, stream: &mut S) -> Result { + if *self { + stream.write_all(&[1]).await?; + } else { + stream.write_all(&[0]).await?; } + Ok(1) } } impl MqttRead for u8 { @@ -336,11 +324,9 @@ impl MqttAsyncWrite for u8 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(1) } } @@ -372,11 +358,9 @@ impl MqttAsyncWrite for u16 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(2) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(2) } } @@ -407,10 +391,8 @@ impl MqttAsyncWrite for u32 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(4) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(4) } } diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs index cce56e1..a1ecf70 100644 --- a/mqrstt/src/packets/primitive/property_type.rs +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -148,11 +148,9 @@ impl MqttAsyncWrite for PropertyType where S: tokio::io::AsyncWrite + std::marker::Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let buf: [u8; 1] = [u8::from(self)]; - stream.write_all(&buf).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [u8::from(self)]; + stream.write_all(&buf).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/primitive/qos.rs b/mqrstt/src/packets/primitive/qos.rs index 572a6a5..06be58e 100644 --- a/mqrstt/src/packets/primitive/qos.rs +++ b/mqrstt/src/packets/primitive/qos.rs @@ -76,11 +76,9 @@ impl MqttAsyncWrite for QoS where S: tokio::io::AsyncWrite + std::marker::Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let buf: [u8; 1] = [self.into_u8()]; - stream.write_all(&buf).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [self.into_u8()]; + stream.write_all(&buf).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs index 9ee609f..1841ead 100644 --- a/mqrstt/src/packets/primitive/variable_integer.rs +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -66,7 +66,7 @@ impl VariableInteger for usize { fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { if *self > 268_435_455 { - return Err(SerializeError::VariableIntegerOverflow(*self as usize)); + return Err(SerializeError::VariableIntegerOverflow(*self)); } let mut write = *self; @@ -82,7 +82,7 @@ impl VariableInteger for usize { return Ok(i + 1); } } - Err(SerializeError::VariableIntegerOverflow(*self as usize)) + Err(SerializeError::VariableIntegerOverflow(*self)) } fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { @@ -105,51 +105,47 @@ impl VariableInteger for usize { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { - async move { - let mut integer = 0; - let mut length = 0; + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; - for i in 0..4 { - let byte = stream.read_u8().await?; - length += 1; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; - integer += (byte as usize & 0x7f) << (7 * i); + integer += (byte as usize & 0x7f) << (7 * i); - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); } - Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } - fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { - async move { - let mut buf = [0u8; 4]; + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; - if *self > 268_435_455 { - return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); - } + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self))); + } - let mut write = *self; - let mut length = 1; - - for i in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf[i] = byte; - if write == 0 { - length = i + 1; - break; - } + let mut write = *self; + let mut length = 1; + + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; } - stream.write_all(&buf[0..length]).await?; - Ok(length) } + stream.write_all(&buf[0..length]).await?; + Ok(length) } } @@ -207,50 +203,46 @@ impl VariableInteger for u32 { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { - async move { - let mut integer = 0; - let mut length = 0; + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; - for i in 0..4 { - let byte = stream.read_u8().await?; - length += 1; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; - integer += (byte as u32 & 0x7f) << (7 * i); + integer += (byte as u32 & 0x7f) << (7 * i); - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); } - Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } - fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { - async move { - let mut buf = [0u8; 4]; + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; - if *self > 268_435_455 { - return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); - } + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); + } + + let mut write = *self; + let mut length = 1; - let mut write = *self; - let mut length = 1; - - for i in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf[i] = byte; - if write == 0 { - length = i + 1; - break; - } + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; } - stream.write_all(&buf[0..length]).await?; - Ok(length) } + stream.write_all(&buf[0..length]).await?; + Ok(length) } } diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 67e3450..3dd5ed7 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -64,45 +64,43 @@ impl PacketAsyncRead for PubComp where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let packet_identifier = stream.read_u16().await?; - if remaining_length == 2 { - return Ok(( - Self { - packet_identifier, - reason_code: PubCompReasonCode::Success, - properties: PubCompProperties::default(), - }, - 2, - )); - } - // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes - else if remaining_length < 4 { - return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); - } - - let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; - let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; - - let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; - - if total_read_bytes != remaining_length { - return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { - read: total_read_bytes, - remaining_length: remaining_length, - })); - } - - Ok(( + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let packet_identifier = stream.read_u16().await?; + if remaining_length == 2 { + return Ok(( Self { packet_identifier, - reason_code, - properties, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), }, - total_read_bytes, - )) + 2, + )); + } + // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes + else if remaining_length < 4 { + return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); + } + + let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; + + let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; + + if total_read_bytes != remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { + read: total_read_bytes, + remaining_length, + })); } + + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + total_read_bytes, + )) } } @@ -126,22 +124,20 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for PubComp where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - async move { - let mut total_writen_bytes = 2; - self.packet_identifier.async_write(stream).await?; - - if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); - } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; - } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; - } - Ok(total_writen_bytes) + let mut total_writen_bytes = 2; + self.packet_identifier.async_write(stream).await?; + + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_writen_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_writen_bytes += self.reason_code.async_write(stream).await?; + } else { + total_writen_bytes += self.reason_code.async_write(stream).await?; + total_writen_bytes += self.properties.async_write(stream).await?; } + Ok(total_writen_bytes) } } diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index 71dc8a3..004dafc 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -96,46 +96,44 @@ impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin, { - fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let dup = flags & 0b1000 != 0; - let qos = QoS::from_u8((flags & 0b110) >> 1)?; - let retain = flags & 0b1 != 0; - - let (topic, topic_read_bytes) = Box::::async_read(stream).await?; - total_read_bytes += topic_read_bytes; - let packet_identifier = if qos == QoS::AtMostOnce { - None - } else { - total_read_bytes += 2; - Some(stream.read_u16().await?) - }; - let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; - total_read_bytes += properties_read_bytes; - - if total_read_bytes > remaining_length { - return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); - } - let payload_len = remaining_length - total_read_bytes; - let mut payload = vec![0u8; payload_len]; - let payload_read_bytes = stream.read_exact(&mut payload).await?; - - assert_eq!(payload_read_bytes, payload_len); - - Ok(( - Self { - dup, - qos, - retain, - topic, - packet_identifier, - publish_properties, - payload, - }, - total_read_bytes + payload_read_bytes, - )) + async fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_bytes; + let packet_identifier = if qos == QoS::AtMostOnce { + None + } else { + total_read_bytes += 2; + Some(stream.read_u16().await?) + }; + let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + if total_read_bytes > remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); } + let payload_len = remaining_length - total_read_bytes; + let mut payload = vec![0u8; payload_len]; + let payload_read_bytes = stream.read_exact(&mut payload).await?; + + assert_eq!(payload_read_bytes, payload_len); + + Ok(( + Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload, + }, + total_read_bytes + payload_read_bytes, + )) } } diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 9fad939..5f02125 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -65,36 +65,34 @@ impl PacketAsyncRead for PubRec where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; - if remaining_length == 2 { - return Ok(( - Self { - packet_identifier, - reason_code: PubRecReasonCode::Success, - properties: PubRecProperties::default(), - }, - total_read_bytes, - )); - } - - let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; - let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; - - total_read_bytes += reason_code_read_bytes + properties_read_bytes; - - Ok(( + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + if remaining_length == 2 { + return Ok(( Self { packet_identifier, - properties, - reason_code, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), }, total_read_bytes, - )) + )); } + + let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; + + total_read_bytes += reason_code_read_bytes + properties_read_bytes; + + Ok(( + Self { + packet_identifier, + properties, + reason_code, + }, + total_read_bytes, + )) } } diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 4bbd138..4f05abe 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -50,32 +50,30 @@ impl PacketAsyncRead for SubAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; - total_read_bytes += 2 + proproperties_read_bytes; - let mut reason_codes = vec![]; - loop { - let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); - - if remaining_length == total_read_bytes { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); - Ok(( - Self { - packet_identifier, - properties, - reason_codes, - }, - total_read_bytes, - )) + if remaining_length == total_read_bytes { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) } } @@ -96,22 +94,20 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for SubAck where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; - - total_written_bytes += self.properties.async_write(stream).await?; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - for reason_code in &self.reason_codes { - reason_code.async_write(stream).await?; - } - total_written_bytes += self.reason_codes.len(); + total_written_bytes += self.properties.async_write(stream).await?; - Ok(total_written_bytes) + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) } } diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index 28f4899..deb39ab 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -61,34 +61,32 @@ impl PacketAsyncRead for Subscribe where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; - total_read_bytes += 2 + proproperties_read_bytes; - - let mut topics = vec![]; - loop { - let (topic, topic_read_bytes) = Box::::async_read(stream).await?; - let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; - total_read_bytes += topic_read_bytes + options_read_bytes; - topics.push((topic, options)); - - if remaining_length >= total_read_bytes { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; - Ok(( - Self { - packet_identifier, - properties, - topics, - }, - total_read_bytes, - )) + let mut topics = vec![]; + loop { + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; + total_read_bytes += topic_read_bytes + options_read_bytes; + topics.push((topic, options)); + + if remaining_length >= total_read_bytes { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } @@ -110,20 +108,18 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Subscribe where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - total_written_bytes += self.properties.async_write(stream).await?; - for (topic, options) in &self.topics { - total_written_bytes += topic.async_write(stream).await?; - total_written_bytes += options.async_write(stream).await?; - } - Ok(total_written_bytes) + total_written_bytes += self.properties.async_write(stream).await?; + for (topic, options) in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + total_written_bytes += options.async_write(stream).await?; } + Ok(total_written_bytes) } } @@ -200,24 +196,22 @@ impl MqttAsyncRead for SubscriptionOptions where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; - let retain_handling_part = (byte & 0b00110000) >> 4; - let retain_as_publish_part = (byte & 0b00001000) >> 3; - let no_local_part = (byte & 0b00000100) >> 2; - let qos_part = byte & 0b00000011; + let retain_handling_part = (byte & 0b00110000) >> 4; + let retain_as_publish_part = (byte & 0b00001000) >> 3; + let no_local_part = (byte & 0b00000100) >> 2; + let qos_part = byte & 0b00000011; - let options = Self { - retain_handling: RetainHandling::from_u8(retain_handling_part)?, - retain_as_publish: retain_as_publish_part != 0, - no_local: no_local_part != 0, - qos: QoS::from_u8(qos_part)?, - }; + let options = Self { + retain_handling: RetainHandling::from_u8(retain_handling_part)?, + retain_as_publish: retain_as_publish_part != 0, + no_local: no_local_part != 0, + qos: QoS::from_u8(qos_part)?, + }; - Ok((options, 1)) - } + Ok((options, 1)) } } @@ -386,7 +380,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Slices -------------------- @@ -395,7 +389,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Vecs -------------------- @@ -404,7 +398,7 @@ where T: IntoSingleSubscription, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.into_iter().map(IntoSingleSubscription::into).collect()) } } impl From<&Vec> for SubscribeTopics @@ -412,7 +406,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &Vec) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index da06d52..3153f3b 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -49,36 +49,34 @@ impl PacketAsyncRead for UnsubAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; - let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; - total_read_bytes += properties_read_bytes; + let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; - let mut reason_codes = vec![]; - loop { - let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); + reason_codes.push(reason_code); - if total_read_bytes >= remaining_length { - break; - } + if total_read_bytes >= remaining_length { + break; } - - Ok(( - Self { - packet_identifier, - properties, - reason_codes, - }, - total_read_bytes, - )) } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) } } diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index 6e9af89..9051f24 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -58,34 +58,32 @@ impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; - total_read_bytes += 2 + properties_read_bytes; - - let mut topics = vec![]; - loop { - let (topic, topic_read_size) = Box::::async_read(stream).await?; - total_read_bytes += topic_read_size; - - topics.push(topic); - - if total_read_bytes >= remaining_length { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + properties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_size) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_size; + + topics.push(topic); - Ok(( - Self { - packet_identifier, - properties, - topics, - }, - total_read_bytes, - )) + if total_read_bytes >= remaining_length { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } @@ -105,19 +103,18 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Unsubscribe where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; - for topic in &self.topics { - total_written_bytes += topic.async_write(stream).await?; - } + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - Ok(total_written_bytes) + for topic in &self.topics { + total_written_bytes += topic.async_write(stream).await?; } + + Ok(total_written_bytes) } } @@ -206,7 +203,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } // -------------------- Slices -------------------- @@ -215,7 +212,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } impl From<&[&str]> for UnsubscribeTopics { @@ -239,7 +236,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(&val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } @@ -248,7 +245,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } diff --git a/mqrstt/src/state.rs b/mqrstt/src/state.rs index 3c52951..4ece423 100644 --- a/mqrstt/src/state.rs +++ b/mqrstt/src/state.rs @@ -34,7 +34,7 @@ pub struct State { impl State { pub fn new(receive_maximum: u16, apkid: AvailablePacketIds) -> Self { - let state = Self { + Self { apkid, outgoing_sub: Mutex::new(BTreeSet::new()), @@ -45,9 +45,7 @@ impl State { outgoing_pub_order: Mutex::new(VecDeque::new()), outgoing_rel: Mutex::new(BTreeSet::new()), incoming_pub: Mutex::new(BTreeSet::new()), - }; - - state + } } pub fn make_pkid_available(&self, pkid: u16) -> Result<(), HandlerError> { diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 3caa0e4..bdfb57b 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -96,13 +96,10 @@ where return Err(ConnectionError::NoNetwork); } - match self.tokio_select(handler).await { - otherwise => { - self.network = None; + let result = self.tokio_select(handler).await; + self.network = None; - otherwise - } - } + result } async fn tokio_select(&mut self, handler: &mut H) -> Result { diff --git a/mqrstt/src/tokio/stream.rs b/mqrstt/src/tokio/stream.rs index a5647e7..9a720ad 100644 --- a/mqrstt/src/tokio/stream.rs +++ b/mqrstt/src/tokio/stream.rs @@ -19,51 +19,47 @@ impl StreamExt for S where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { - fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future> { - async move { - let connect = options.create_connect_from_options(); + async fn connect(&mut self, options: &ConnectOptions) -> Result { + let connect = options.create_connect_from_options(); - self.write_packet(&connect).await?; + self.write_packet(&connect).await?; - let packet = Packet::async_read(self).await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - #[cfg(feature = "logs")] - trace!("Connected to server"); - Ok(con) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } + let packet = Packet::async_read(self).await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); + Ok(con) } else { - Err(ConnectionError::NotConnAck(packet)) + Err(ConnectionError::ConnectionRefused(con.reason_code)) } + } else { + Err(ConnectionError::NotConnAck(packet)) } } - fn read_packet(&mut self) -> impl std::future::Future> { - async move { Ok(Packet::async_read(self).await?) } + async fn read_packet(&mut self) -> Result { + Ok(Packet::async_read(self).await?) } - fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future> { - async move { - match packet.async_write(self).await { - Ok(_) => (), - Err(err) => { - return match err { - crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), - crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), - } + async fn write_packet(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + match packet.async_write(self).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), } } + } - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); - self.flush().await?; - // self.flush_packets().await?; + self.flush().await?; + // self.flush_packets().await?; - Ok(()) - } + Ok(()) } async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> { From fd71dfdcfc4e6bd5c4ef64a45d545fb94c94d8d2 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 01:26:13 +0100 Subject: [PATCH 36/45] update cargo deny --- deny.toml | 3 +-- examples/tcp/Cargo.toml | 1 + fuzz/Cargo.toml | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deny.toml b/deny.toml index ee3595b..acabff1 100644 --- a/deny.toml +++ b/deny.toml @@ -2,11 +2,10 @@ [licenses] confidence-threshold = 0.95 -allow = ["MPL-2.0", "Apache-2.0", "MIT", "BSD-3-Clause", "ISC", "Unicode-3.0"] +allow = ["MPL-2.0", "Apache-2.0", "MIT", "Unicode-3.0"] exceptions = [ { allow = ["Unicode-DFS-2016"], name = "unicode-ident" }, - { allow = ["OpenSSL"], name = "ring" } ] [[licenses.clarify]] diff --git a/examples/tcp/Cargo.toml b/examples/tcp/Cargo.toml index 78834ec..b42bc70 100644 --- a/examples/tcp/Cargo.toml +++ b/examples/tcp/Cargo.toml @@ -2,6 +2,7 @@ name = "tcp" version = "0.1.0" edition = "2021" +license = "MIT" [dependencies] smol = { version = "2" } diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index adb4c3e..efee995 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -3,6 +3,7 @@ name = "mqrstt-fuzz" version = "0.0.0" publish = false edition = "2021" +license = "MIT" [package.metadata] cargo-fuzz = true From 0a009ca7e44ee6b8c90a03e57edb2871f8499665 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 02:58:57 +0100 Subject: [PATCH 37/45] add suback and unsuback test cases --- mqrstt/src/lib.rs | 2 +- mqrstt/src/packets/auth/mod.rs | 6 +- mqrstt/src/packets/connack/mod.rs | 8 +- mqrstt/src/packets/connect/mod.rs | 16 +- .../src/packets/macros/properties_macros.rs | 172 +++++++++--------- mqrstt/src/packets/mod.rs | 133 +++++++++++--- mqrstt/src/packets/pubcomp/mod.rs | 12 +- mqrstt/src/packets/pubrec/mod.rs | 12 +- mqrstt/src/packets/pubrel/mod.rs | 12 +- mqrstt/src/packets/suback/mod.rs | 32 ++-- mqrstt/src/packets/suback/properties.rs | 3 +- mqrstt/src/packets/unsuback/mod.rs | 33 ++-- mqrstt/src/packets/unsuback/properties.rs | 2 +- mqrstt/src/tests/test_packets.rs | 54 ++++++ 14 files changed, 322 insertions(+), 175 deletions(-) diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index e4d0477..e5ac218 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -409,7 +409,7 @@ mod tokio_lib_test { network.connect(stream, &mut pingresp).await.unwrap(); let network_handle = tokio::task::spawn(async move { - let result = network.run(&mut pingresp).await; + let _result = network.run(&mut pingresp).await; // check result and or restart the connection pingresp }); diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index e16fdda..6f55d4d 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -48,9 +48,9 @@ where S: tokio::io::AsyncWrite + Unpin, { async fn async_write(&self, stream: &mut S) -> Result { - let reason_code_writen = self.reason_code.async_write(stream).await?; - let properties_writen = self.properties.async_write(stream).await?; - Ok(reason_code_writen + properties_writen) + let reason_code_written = self.reason_code.async_write(stream).await?; + let properties_written = self.properties.async_write(stream).await?; + Ok(reason_code_written + properties_written) } } diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 94da71f..90dc715 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -85,11 +85,11 @@ where { async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - let connack_flags_writen = self.connack_flags.async_write(stream).await?; - let reason_code_writen = self.reason_code.async_write(stream).await?; - let connack_properties_writen = self.connack_properties.async_write(stream).await?; + let connack_flags_written = self.connack_flags.async_write(stream).await?; + let reason_code_written = self.reason_code.async_write(stream).await?; + let connack_properties_written = self.connack_properties.async_write(stream).await?; - Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) + Ok(connack_flags_written + reason_code_written + connack_properties_written) } } diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs index 463f6e4..cb5a611 100644 --- a/mqrstt/src/packets/connect/mod.rs +++ b/mqrstt/src/packets/connect/mod.rs @@ -223,12 +223,12 @@ where use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; async move { - let mut total_writen_bytes = 6 // protocol header + let mut total_written_bytes = 6 // protocol header + 1 // protocol version + 1 // connect flags + 2; // keep alive let protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; - // We allready start with 6 as total writen bytes thus dont add anymore + // We allready start with 6 as total written bytes thus dont add anymore stream.write_all(&protocol).await?; self.protocol_version.async_write(stream).await?; @@ -250,21 +250,21 @@ where stream.write_u16(self.keep_alive).await?; - total_writen_bytes += self.connect_properties.async_write(stream).await?; + total_written_bytes += self.connect_properties.async_write(stream).await?; - total_writen_bytes += self.client_id.async_write(stream).await?; + total_written_bytes += self.client_id.async_write(stream).await?; if let Some(last_will) = &self.last_will { - total_writen_bytes += last_will.async_write(stream).await?; + total_written_bytes += last_will.async_write(stream).await?; } if let Some(username) = &self.username { - total_writen_bytes += username.async_write(stream).await?; + total_written_bytes += username.async_write(stream).await?; } if let Some(password) = &self.password { - total_writen_bytes += password.async_write(stream).await?; + total_written_bytes += password.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index b71e149..33d80ba 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -35,13 +35,13 @@ macro_rules! define_properties { impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { async fn async_write(&self, stream: &mut S) -> Result { - let mut bytes_writen = 0; + let mut bytes_written = 0; $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; $( - $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); + $crate::packets::macros::properties_write!(self, bytes_written, stream, PropertyType::$prop_variant); )* - Ok(bytes_writen) + Ok(bytes_written) } } @@ -556,173 +556,173 @@ macro_rules! properties_read_match_branch_name { } macro_rules! properties_write { - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { if let Some(payload_format_indicator) = &($self.payload_format_indicator) { - $bytes_writen += PropertyType::PayloadFormatIndicator.async_write($stream).await?; - $bytes_writen += payload_format_indicator.async_write($stream).await?; + $bytes_written += PropertyType::PayloadFormatIndicator.async_write($stream).await?; + $bytes_written += payload_format_indicator.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { if let Some(message_expiry_interval) = &($self.message_expiry_interval) { - $bytes_writen += PropertyType::MessageExpiryInterval.async_write($stream).await?; - $bytes_writen += message_expiry_interval.async_write($stream).await?; + $bytes_written += PropertyType::MessageExpiryInterval.async_write($stream).await?; + $bytes_written += message_expiry_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ContentType) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ContentType) => { if let Some(content_type) = &($self.content_type) { - $bytes_writen += PropertyType::ContentType.async_write($stream).await?; - $bytes_writen += content_type.async_write($stream).await?; + $bytes_written += PropertyType::ContentType.async_write($stream).await?; + $bytes_written += content_type.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseTopic) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseTopic) => { if let Some(response_topic) = &($self.response_topic) { - $bytes_writen += PropertyType::ResponseTopic.async_write($stream).await?; - $bytes_writen += response_topic.as_ref().async_write($stream).await?; + $bytes_written += PropertyType::ResponseTopic.async_write($stream).await?; + $bytes_written += response_topic.as_ref().async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::CorrelationData) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::CorrelationData) => { if let Some(correlation_data) = &($self.correlation_data) { - $bytes_writen += PropertyType::CorrelationData.async_write($stream).await?; - $bytes_writen += correlation_data.async_write($stream).await?; + $bytes_written += PropertyType::CorrelationData.async_write($stream).await?; + $bytes_written += correlation_data.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { if let Some(sub_id) = &($self.subscription_identifier) { - $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; - $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { for sub_id in &($self.subscription_identifiers) { - $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; - $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { if let Some(session_expiry_interval) = &($self.session_expiry_interval) { - $bytes_writen += PropertyType::SessionExpiryInterval.async_write($stream).await?; - $bytes_writen += session_expiry_interval.async_write($stream).await?; + $bytes_written += PropertyType::SessionExpiryInterval.async_write($stream).await?; + $bytes_written += session_expiry_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerKeepAlive) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerKeepAlive) => { if let Some(server_keep_alive) = &($self.server_keep_alive) { - $bytes_writen += PropertyType::ServerKeepAlive.async_write($stream).await?; - $bytes_writen += server_keep_alive.async_write($stream).await?; + $bytes_written += PropertyType::ServerKeepAlive.async_write($stream).await?; + $bytes_written += server_keep_alive.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationMethod) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationMethod) => { if let Some(authentication_method) = &($self.authentication_method) { - $bytes_writen += PropertyType::AuthenticationMethod.async_write($stream).await?; - $bytes_writen += authentication_method.async_write($stream).await?; + $bytes_written += PropertyType::AuthenticationMethod.async_write($stream).await?; + $bytes_written += authentication_method.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationData) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationData) => { if let Some(authentication_data) = &($self.authentication_data) { if !authentication_data.is_empty() && ($self.authentication_method).is_some() { - $bytes_writen += PropertyType::AuthenticationData.async_write($stream).await?; - $bytes_writen += authentication_data.async_write($stream).await?; + $bytes_written += PropertyType::AuthenticationData.async_write($stream).await?; + $bytes_written += authentication_data.async_write($stream).await?; } } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestProblemInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestProblemInformation) => { if let Some(request_problem_information) = &($self.request_problem_information) { - $bytes_writen += PropertyType::RequestProblemInformation.async_write($stream).await?; - $bytes_writen += request_problem_information.async_write($stream).await?; + $bytes_written += PropertyType::RequestProblemInformation.async_write($stream).await?; + $bytes_written += request_problem_information.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WillDelayInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WillDelayInterval) => { if let Some(delay_interval) = &($self.will_delay_interval) { - $bytes_writen += PropertyType::WillDelayInterval.async_write($stream).await?; - $bytes_writen += delay_interval.async_write($stream).await?; + $bytes_written += PropertyType::WillDelayInterval.async_write($stream).await?; + $bytes_written += delay_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestResponseInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestResponseInformation) => { if let Some(request_response_information) = &($self.request_response_information) { - $bytes_writen += PropertyType::RequestResponseInformation.async_write($stream).await?; - $bytes_writen += request_response_information.async_write($stream).await?; + $bytes_written += PropertyType::RequestResponseInformation.async_write($stream).await?; + $bytes_written += request_response_information.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseInformation) => { if let Some(response_info) = &($self.response_info) { - $bytes_writen += PropertyType::ResponseInformation.async_write($stream).await?; - $bytes_writen += response_info.async_write($stream).await?; + $bytes_written += PropertyType::ResponseInformation.async_write($stream).await?; + $bytes_written += response_info.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerReference) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerReference) => { if let Some(server_refrence) = &($self.server_reference) { - $bytes_writen += PropertyType::ServerReference.async_write($stream).await?; + $bytes_written += PropertyType::ServerReference.async_write($stream).await?; server_refrence.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReasonString) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReasonString) => { if let Some(reason_string) = &($self.reason_string) { - $bytes_writen += PropertyType::ReasonString.async_write($stream).await?; - $bytes_writen += reason_string.async_write($stream).await?; + $bytes_written += PropertyType::ReasonString.async_write($stream).await?; + $bytes_written += reason_string.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReceiveMaximum) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReceiveMaximum) => { if let Some(receive_maximum) = &($self.receive_maximum) { - $bytes_writen += PropertyType::ReceiveMaximum.async_write($stream).await?; - $bytes_writen += receive_maximum.async_write($stream).await?; + $bytes_written += PropertyType::ReceiveMaximum.async_write($stream).await?; + $bytes_written += receive_maximum.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { if let Some(topic_alias_maximum) = &($self.topic_alias_maximum) { - $bytes_writen += PropertyType::TopicAliasMaximum.async_write($stream).await?; - $bytes_writen += topic_alias_maximum.async_write($stream).await?; + $bytes_written += PropertyType::TopicAliasMaximum.async_write($stream).await?; + $bytes_written += topic_alias_maximum.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAlias) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAlias) => { if let Some(topic_alias) = &($self.topic_alias) { - $bytes_writen += PropertyType::TopicAlias.async_write($stream).await?; - $bytes_writen += topic_alias.async_write($stream).await?; + $bytes_written += PropertyType::TopicAlias.async_write($stream).await?; + $bytes_written += topic_alias.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumQos) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumQos) => { if let Some(maximum_qos) = &($self.maximum_qos) { - $bytes_writen += PropertyType::MaximumQos.async_write($stream).await?; - $bytes_writen += maximum_qos.async_write($stream).await?; + $bytes_written += PropertyType::MaximumQos.async_write($stream).await?; + $bytes_written += maximum_qos.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RetainAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RetainAvailable) => { if let Some(retain_available) = &($self.retain_available) { - $bytes_writen += PropertyType::RetainAvailable.async_write($stream).await?; - $bytes_writen += retain_available.async_write($stream).await?; + $bytes_written += PropertyType::RetainAvailable.async_write($stream).await?; + $bytes_written += retain_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::UserProperty) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::UserProperty) => { for (key, value) in &($self.user_properties) { - $bytes_writen += PropertyType::UserProperty.async_write($stream).await?; - $bytes_writen += key.async_write($stream).await?; - $bytes_writen += value.async_write($stream).await?; + $bytes_written += PropertyType::UserProperty.async_write($stream).await?; + $bytes_written += key.async_write($stream).await?; + $bytes_written += value.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumPacketSize) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumPacketSize) => { if let Some(maximum_packet_size) = &($self.maximum_packet_size) { - $bytes_writen += PropertyType::MaximumPacketSize.async_write($stream).await?; - $bytes_writen += maximum_packet_size.async_write($stream).await?; + $bytes_written += PropertyType::MaximumPacketSize.async_write($stream).await?; + $bytes_written += maximum_packet_size.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { if let Some(wildcards_available) = &($self.wildcards_available) { - $bytes_writen += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; - $bytes_writen += wildcards_available.async_write($stream).await?; + $bytes_written += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; + $bytes_written += wildcards_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { if let Some(subscription_ids_available) = &($self.subscription_ids_available) { - $bytes_writen += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; - $bytes_writen += subscription_ids_available.async_write($stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; + $bytes_written += subscription_ids_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { if let Some(shared_subscription_available) = &($self.shared_subscription_available) { - $bytes_writen += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; - $bytes_writen += shared_subscription_available.async_write($stream).await?; + $bytes_written += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; + $bytes_written += shared_subscription_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, $unknown:ident) => { + ($self:ident, $bytes_written:ident, $stream:ident, $unknown:ident) => { compile_error!(concat!("Unknown property: ", stringify!($unknown))); }; } diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 2eb135a..39fbc88 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -103,9 +103,7 @@ impl Packet { Packet::PubRel(_) => 0b0110_0010, Packet::PubComp(_) => 0b0111_0000, Packet::Subscribe(_) => 0b1000_0010, - Packet::SubAck(_) => { - unreachable!() - } + Packet::SubAck(_) => 0b1001_0000, Packet::Unsubscribe(_) => 0b1010_0010, Packet::UnsubAck(_) => 0b1011_0000, Packet::PingReq => 0b1100_0000, @@ -168,17 +166,20 @@ impl Packet { p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::SubAck(_) => { - unreachable!() + Packet::SubAck(p) => { + buf.put_u8(0b1001_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::Unsubscribe(p) => { buf.put_u8(0b1010_0010); p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::UnsubAck(_) => { - unreachable!(); - // buf.put_u8(0b1011_0000); + Packet::UnsubAck(p) => { + buf.put_u8(0b1011_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::PingReq => { buf.put_u8(0b1100_0000); @@ -259,17 +260,20 @@ impl Packet { written += p.wire_len().write_async_variable_integer(stream).await?; written += p.async_write(stream).await?; } - Packet::SubAck(_) => { - unreachable!() + Packet::SubAck(p) => { + stream.write_u8(0b1001_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; } Packet::Unsubscribe(p) => { stream.write_u8(0b1010_0010).await?; written += p.wire_len().write_async_variable_integer(stream).await?; written += p.async_write(stream).await?; } - Packet::UnsubAck(_) => { - unreachable!(); - // stream.write_u8(0b1011_0000).await?; + Packet::UnsubAck(p) => { + stream.write_u8(0b1011_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; } Packet::PingReq => { stream.write_u8(0b1100_0000).await?; @@ -396,6 +400,28 @@ impl Display for Packet { } } +impl WireLength for Packet { + fn wire_len(&self) -> usize { + match self { + Packet::Connect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::ConnAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Publish(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRec(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRel(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubComp(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Subscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::SubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Unsubscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::UnsubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PingReq => 2, + Packet::PingResp => 2, + Packet::Disconnect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Auth(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + } + } +} + /// 2.1.2 MQTT Control Packet type #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum PacketType { @@ -447,6 +473,7 @@ impl std::fmt::Display for PacketType { #[cfg(test)] mod tests { + use bytes::BytesMut; use crate::packets::Packet; @@ -454,26 +481,39 @@ mod tests { use crate::tests::test_packets::*; #[rstest::rstest] - #[case(ping_req_case().1)] - #[case(ping_resp_case().1)] - #[case(connack_case().1)] - #[case(create_subscribe_packet(1))] - #[case(create_subscribe_packet(65335))] - #[case(create_puback_packet(1))] - #[case(create_puback_packet(65335))] - #[case(create_disconnect_packet())] - #[case(create_connack_packet(true))] - #[case(create_connack_packet(false))] - #[case(publish_packet_1())] - #[case(publish_packet_2())] - #[case(publish_packet_3())] - #[case(publish_packet_4())] - #[case(create_empty_publish_packet())] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] fn test_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + let mut buffer = BytesMut::new(); + packet.write(&mut buffer).unwrap(); + + let wire_len = packet.wire_len(); + assert_eq!(wire_len, buffer.len()); + let res1 = Packet::read(&mut buffer).unwrap(); + assert_eq!(packet, res1); + let mut buffer = BytesMut::new(); res1.write(&mut buffer).unwrap(); let res2 = Packet::read(&mut buffer).unwrap(); @@ -533,6 +573,43 @@ mod tests { assert_eq!(out, input) } + #[rstest::rstest] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] + #[tokio::test] + async fn test_async_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + + let mut buffer = Vec::with_capacity(1000); + packet.async_write(&mut buffer).await.unwrap(); + + let wire_len = packet.wire_len(); + assert_eq!(wire_len, buffer.len()); + + let mut buf = buffer.as_slice(); + + let res1 = Packet::async_read(&mut buf).await.unwrap(); + + assert_eq!(packet, res1); + } + // #[rstest::rstest] // #[case(&[59, 1, 0, 59])] // #[case(&[16, 14, 0, 4, 77, 81, 84, 84, 5, 247, 247, 252, 1, 17, 247, 247, 247])] diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 3dd5ed7..459b036 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -126,18 +126,18 @@ where { async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 5f02125..92ca8be 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -118,18 +118,18 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index c97d285..03d78d9 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -116,18 +116,18 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 4f05abe..df8b1a2 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -1,4 +1,5 @@ mod properties; + pub use properties::SubAckProperties; mod reason_code; @@ -7,6 +8,7 @@ pub use reason_code::SubAckReasonCode; use super::{ error::SerializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + VariableInteger, WireLength, }; use bytes::BufMut; use tokio::io::AsyncReadExt; @@ -23,19 +25,21 @@ pub struct SubAck { } impl PacketRead for SubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; - loop { - let reason_code = SubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - if buf.is_empty() { + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read >= remaining_length { break; } + + let reason_code = SubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; } Ok(Self { @@ -57,13 +61,13 @@ where total_read_bytes += 2 + proproperties_read_bytes; let mut reason_codes = vec![]; loop { - let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); - if remaining_length == total_read_bytes { break; } + + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); } Ok(( @@ -111,6 +115,12 @@ where } } +impl WireLength for SubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + #[cfg(test)] mod test { use bytes::BytesMut; @@ -128,7 +138,7 @@ mod test { ]; let data = BytesMut::from(&buf[..]); - let sub_ack = SubAck::read(0, 0, data.clone().into()).unwrap(); + let sub_ack = SubAck::read(0, 5, data.clone().into()).unwrap(); let mut result = BytesMut::new(); sub_ack.write(&mut result).unwrap(); diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs index be7ac8b..15f6997 100644 --- a/mqrstt/src/packets/suback/properties.rs +++ b/mqrstt/src/packets/suback/properties.rs @@ -44,7 +44,7 @@ impl MqttRead for SubAckProperties { e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), } - if buf.is_empty() { + if properties_data.is_empty() { break; } } @@ -58,7 +58,6 @@ impl MqttWrite for SubAckProperties { if let Some(sub_id) = self.subscription_identifier { PropertyType::SubscriptionIdentifier.write(buf)?; sub_id.write_variable_integer(buf)?; - // sub_id.write_variable_integer(buf)?; } for (key, value) in &self.user_properties { PropertyType::UserProperty.write(buf)?; diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index 3153f3b..9823bed 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -12,7 +12,7 @@ use tokio::io::AsyncReadExt; use super::error::SerializeError; use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; -use super::PacketAsyncRead; +use super::{PacketAsyncRead, VariableInteger, WireLength}; /// UnsubAck packet is sent by the server in response to an [`crate::packets::Unsubscribe`] packet. #[derive(Debug, Default, PartialEq, Eq, Clone)] @@ -23,18 +23,20 @@ pub struct UnsubAck { } impl PacketRead for UnsubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; - loop { - let reason_code = UnsubAckReasonCode::read(&mut buf)?; - reason_codes.push(reason_code); - - if buf.is_empty() { + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read == remaining_length { break; } + + let reason_code = UnsubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; } Ok(Self { @@ -50,23 +52,22 @@ where S: tokio::io::AsyncRead + Unpin, { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { - let mut total_read_bytes = 0; + let mut total_read_bytes = 2; let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; total_read_bytes += properties_read_bytes; let mut reason_codes = vec![]; loop { + if total_read_bytes >= remaining_length { + break; + } + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; total_read_bytes += reason_code_read_bytes; reason_codes.push(reason_code); - - if total_read_bytes >= remaining_length { - break; - } } Ok(( @@ -114,6 +115,12 @@ where } } +impl WireLength for UnsubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs index 85dcab0..d462c50 100644 --- a/mqrstt/src/packets/unsuback/properties.rs +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -38,7 +38,7 @@ impl MqttRead for UnsubAckProperties { e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), } - if buf.is_empty() { + if properties_data.is_empty() { break; } } diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index c8351e5..2f9dac1 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -276,6 +276,60 @@ pub fn create_disconnect_packet() -> Packet { }) } +pub fn suback_case() -> Packet { + let expected = SubAck { + packet_identifier: 3, + reason_codes: vec![SubAckReasonCode::GrantedQoS0, SubAckReasonCode::GrantedQoS1, SubAckReasonCode::GrantedQoS2], + properties: SubAckProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::SubAck(expected) +} + +pub fn subscribe_case() -> Packet { + let expected = Subscribe { + packet_identifier: 3, + topics: vec![("test/topic".into(), SubscriptionOptions::default())], + properties: SubscribeProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::Subscribe(expected) +} + +// return a crazy big packet +pub fn unsuback_case() -> Packet { + let expected = UnsubAck { + packet_identifier: 3, + reason_codes: vec![ + UnsubAckReasonCode::NoSubscriptionExisted, + UnsubAckReasonCode::UnspecifiedError, + UnsubAckReasonCode::ImplementationSpecificError, + ], + properties: UnsubAckProperties { + user_properties: vec![], + reason_string: None, + }, + }; + + Packet::UnsubAck(expected) +} + +pub fn unsubscribe_case() -> Packet { + let expected = Unsubscribe { + packet_identifier: 3, + topics: vec!["test/topic".into()], + properties: UnsubscribeProperties { user_properties: vec![] }, + }; + + Packet::Unsubscribe(expected) +} + #[rstest] #[case(create_subscribe_packet(1))] #[case(create_subscribe_packet(65335))] From 621b5178908daea8f22baf3f25e15c198c13c734 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 03:28:09 +0100 Subject: [PATCH 38/45] fix write unsubcribe --- mqrstt/src/packets/mod.rs | 7 ++++++- mqrstt/src/packets/unsubscribe/mod.rs | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 39fbc88..8d6410e 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -509,6 +509,9 @@ mod tests { let wire_len = packet.wire_len(); assert_eq!(wire_len, buffer.len()); + dbg!(wire_len); + let a: Vec<_> = buffer.iter().map(|f| *f as u16).collect(); + println!("{:?}", a); let res1 = Packet::read(&mut buffer).unwrap(); @@ -598,9 +601,11 @@ mod tests { use crate::packets::WireLength; let mut buffer = Vec::with_capacity(1000); - packet.async_write(&mut buffer).await.unwrap(); + let res = packet.async_write(&mut buffer).await.unwrap(); let wire_len = packet.wire_len(); + + assert_eq!(res, wire_len); assert_eq!(wire_len, buffer.len()); let mut buf = buffer.as_slice(); diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index 9051f24..e316623 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -110,6 +110,8 @@ where let mut total_written_bytes = 2; stream.write_u16(self.packet_identifier).await?; + total_written_bytes += self.properties.async_write(stream).await?; + for topic in &self.topics { total_written_bytes += topic.async_write(stream).await?; } From 5caa7011494a8cd429ef41d13e8d0b8b794520e2 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:58:22 +0100 Subject: [PATCH 39/45] fix wire_length and write length --- mqrstt/src/packets/macros/properties_macros.rs | 2 +- mqrstt/src/packets/mod.rs | 4 +++- mqrstt/src/tests/test_packets.rs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index 33d80ba..b7b8438 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -36,7 +36,7 @@ macro_rules! define_properties { impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { async fn async_write(&self, stream: &mut S) -> Result { let mut bytes_written = 0; - $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; + bytes_written += $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; $( $crate::packets::macros::properties_write!(self, bytes_written, stream, PropertyType::$prop_variant); )* diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 8d6410e..7fd4a0a 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -278,10 +278,12 @@ impl Packet { Packet::PingReq => { stream.write_u8(0b1100_0000).await?; stream.write_u8(0).await?; // Variable header length. + written += 1; } Packet::PingResp => { stream.write_u8(0b1101_0000).await?; stream.write_u8(0).await?; // Variable header length. + written += 1; } Packet::Disconnect(p) => { stream.write_u8(0b1110_0000).await?; @@ -605,7 +607,7 @@ mod tests { let wire_len = packet.wire_len(); - assert_eq!(res, wire_len); + assert_eq!(res, buffer.len()); assert_eq!(wire_len, buffer.len()); let mut buf = buffer.as_slice(); diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index 2f9dac1..ffc8f18 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -324,7 +324,7 @@ pub fn unsubscribe_case() -> Packet { let expected = Unsubscribe { packet_identifier: 3, topics: vec!["test/topic".into()], - properties: UnsubscribeProperties { user_properties: vec![] }, + properties: UnsubscribeProperties { user_properties: vec![("written += 1;".into(), "value".into())] }, }; Packet::Unsubscribe(expected) From 3439e958d96715e78ed8327994bf6dfb33daff11 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:21:06 +0100 Subject: [PATCH 40/45] examples in readme and example dir --- examples/tcp/Cargo.toml | 9 ++ examples/tcp/src/ping_pong.rs | 54 ++++++++++ examples/tcp/src/ping_pong_smol.rs | 52 ++++++++++ examples/tcp/src/tokio.rs | 1 + README.md => mqrstt/README.md | 160 +++++++++++------------------ mqrstt/src/packets/error.rs | 2 +- mqrstt/src/tokio/network.rs | 2 + 7 files changed, 179 insertions(+), 101 deletions(-) create mode 100644 examples/tcp/src/ping_pong.rs create mode 100644 examples/tcp/src/ping_pong_smol.rs rename README.md => mqrstt/README.md (61%) diff --git a/examples/tcp/Cargo.toml b/examples/tcp/Cargo.toml index b42bc70..15e0444 100644 --- a/examples/tcp/Cargo.toml +++ b/examples/tcp/Cargo.toml @@ -6,6 +6,7 @@ license = "MIT" [dependencies] smol = { version = "2" } +futures = "0.3.31" tokio = { version = "1", features = ["full"] } @@ -15,6 +16,14 @@ mqrstt = { path = "../../mqrstt", features = ["logs"] } name = "tokio" path = "src/tokio.rs" +[[bin]] +name = "ping_pong" +path = "src/ping_pong.rs" + +[[bin]] +name = "ping_pong_smol" +path = "src/ping_pong_smol.rs" + [[bin]] name = "smol" path = "src/smol.rs" diff --git a/examples/tcp/src/ping_pong.rs b/examples/tcp/src/ping_pong.rs new file mode 100644 index 0000000..3081554 --- /dev/null +++ b/examples/tcp/src/ping_pong.rs @@ -0,0 +1,54 @@ +use mqrstt::{ + packets::{self, Packet}, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, +}; +use tokio::time::Duration; + +pub struct PingPong { + pub client: MqttClient, +} +impl AsyncEventHandler for PingPong { + // Handlers only get INCOMING packets. + async fn handle(&mut self, event: packets::Packet) { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); + println!("Received Ping, Send pong!"); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } +} + +#[tokio::main] +async fn main() { + let (mut network, client) = NetworkBuilder::new_from_client_id("TokioTcpPingPongExample").tokio_network(); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); + + let mut pingpong = PingPong { client: client.clone() }; + + network.connect(stream, &mut pingpong).await.unwrap(); + + client.subscribe("mqrstt").await.unwrap(); + + let network_handle = tokio::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + tokio::time::sleep(Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = network_handle.await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); +} diff --git a/examples/tcp/src/ping_pong_smol.rs b/examples/tcp/src/ping_pong_smol.rs new file mode 100644 index 0000000..2647daa --- /dev/null +++ b/examples/tcp/src/ping_pong_smol.rs @@ -0,0 +1,52 @@ +use mqrstt::{ + packets::{self, Packet}, + AsyncEventHandler, ConnectOptions, MqttClient, NetworkBuilder, NetworkStatus, +}; +pub struct PingPong { + pub client: MqttClient, +} +impl AsyncEventHandler for PingPong { + // Handlers only get INCOMING packets. This can change later. + async fn handle(&mut self, event: packets::Packet) { + match event { + Packet::Publish(p) => { + if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { + if payload.to_lowercase().contains("ping") { + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); + println!("Received Ping, Send pong!"); + } + } + } + Packet::ConnAck(_) => { + println!("Connected!") + } + _ => (), + } + } +} +fn main() { + smol::block_on(async { + let (mut network, client) = NetworkBuilder::new_from_client_id("mqrsttSmolExample").smol_network(); + let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + + let mut pingpong = PingPong { client: client.clone() }; + + network.connect(stream, &mut pingpong).await.unwrap(); + + // This subscribe is only processed when we run the network + client.subscribe("mqrstt").await.unwrap(); + + let task_handle = smol::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + smol::Timer::after(std::time::Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = task_handle.await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); + }); +} diff --git a/examples/tcp/src/tokio.rs b/examples/tcp/src/tokio.rs index 1e9693b..e3db001 100644 --- a/examples/tcp/src/tokio.rs +++ b/examples/tcp/src/tokio.rs @@ -21,6 +21,7 @@ async fn main() { let mut handler = Handler { byte_count: 0 }; let stream = tokio::net::TcpStream::connect(hostname).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); let (mut network, client) = mqrstt::NetworkBuilder::new_from_client_id("TestClientABCDEFG").tokio_network(); network.connect(stream, &mut handler).await.unwrap(); diff --git a/README.md b/mqrstt/README.md similarity index 61% rename from README.md rename to mqrstt/README.md index 7e112dc..8a7fb01 100644 --- a/README.md +++ b/mqrstt/README.md @@ -24,9 +24,8 @@ For a sync approach the stream has to implement the [`std::io::Read`] and [`std: - Keep alive depends on actual communication ### To do -- no_std (Requires a lot of work to use no heap allocations and depend on stack) - Even More testing -- More documentation +- Add TLS examples to repository ## MSRV From 0.3 the tokio and smol variants will require MSRV: 1.75 due to async fn in trait feature. @@ -38,119 +37,90 @@ From 0.3 the tokio and smol variants will require MSRV: 1.75 due to async fn in - Create a new connection when an error or disconnect is encountered - Handlers only get incoming packets -### TLS: -TLS examples are too larger for a README. [TLS examples](https://github.com/GunnarMorrigan/mqrstt/tree/main/examples). ### Smol example: ```rust use mqrstt::{ - MqttClient, - ConnectOptions, - new_smol, packets::{self, Packet}, - AsyncEventHandler, - smol::NetworkStatus, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, }; -use bytes::Bytes; pub struct PingPong { pub client: MqttClient, } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { if payload.to_lowercase().contains("ping") { - self.client - .publish( - p.topic.clone(), - p.qos, - p.retain, - Bytes::from_static(b"pong"), - ) - .await - .unwrap(); + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); println!("Received Ping, Send pong!"); } } - }, - Packet::ConnAck(_) => { println!("Connected!") }, + } + Packet::ConnAck(_) => { + println!("Connected!") + } _ => (), } } } -smol::block_on(async { - let options = ConnectOptions::new("mqrsttSmolExample"); - let (mut network, client) = new_smol(options); - let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); - - let mut pingpong = PingPong { - client: client.clone(), - }; +fn main() { + smol::block_on(async { + let (mut network, client) = NetworkBuilder::new_from_client_id("mqrsttSmolExample").smol_network(); + let stream = smol::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); - network.connect(stream, &mut pingpong).await.unwrap(); + let mut pingpong = PingPong { client: client.clone() }; - // This subscribe is only processed when we run the network - client.subscribe("mqrstt").await.unwrap(); + network.connect(stream, &mut pingpong).await.unwrap(); + + // This subscribe is only processed when we run the network + client.subscribe("mqrstt").await.unwrap(); + + let task_handle = smol::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + smol::Timer::after(std::time::Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = task_handle.await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); + }); +} - let (n, t) = futures::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - smol::Timer::after(std::time::Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); -}); ``` ### Tokio example: ```rust use mqrstt::{ - MqttClient, - ConnectOptions, - new_tokio, packets::{self, Packet}, - AsyncEventHandler, - tokio::NetworkStatus, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, }; use tokio::time::Duration; -use bytes::Bytes; pub struct PingPong { pub client: MqttClient, } impl AsyncEventHandler for PingPong { - // Handlers only get INCOMING packets. This can change later. + // Handlers only get INCOMING packets. async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { if payload.to_lowercase().contains("ping") { - self.client - .publish( - p.topic.clone(), - p.qos, - p.retain, - Bytes::from_static(b"pong"), - ) - .await - .unwrap(); + self.client.publish(p.topic.clone(), p.qos, p.retain, b"pong").await.unwrap(); println!("Received Ping, Send pong!"); } } - }, - Packet::ConnAck(_) => { println!("Connected!") }, + } + Packet::ConnAck(_) => { + println!("Connected!") + } _ => (), } } @@ -158,39 +128,30 @@ impl AsyncEventHandler for PingPong { #[tokio::main] async fn main() { - let options = ConnectOptions::new("TokioTcpPingPongExample"); - - let (mut network, client) = new_tokio(options); - - let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)) - .await - .unwrap(); - - let mut pingpong = PingPong { - client: client.clone(), - }; - + let (mut network, client) = NetworkBuilder::new_from_client_id("TokioTcpPingPongExample").tokio_network(); + + let stream = tokio::net::TcpStream::connect(("broker.emqx.io", 1883)).await.unwrap(); + let stream = tokio::io::BufStream::new(stream); + + let mut pingpong = PingPong { client: client.clone() }; + network.connect(stream, &mut pingpong).await.unwrap(); - + client.subscribe("mqrstt").await.unwrap(); - - - let (n, _) = tokio::join!( - async { - loop { - return match network.poll(&mut pingpong).await { - Ok(NetworkStatus::Active) => continue, - otherwise => otherwise, - }; - } - }, - async { - tokio::time::sleep(Duration::from_secs(30)).await; - client.disconnect().await.unwrap(); - } - ); - assert!(n.is_ok()); + + let network_handle = tokio::spawn(async move { + let result = network.run(&mut pingpong).await; + (result, pingpong) + }); + + tokio::time::sleep(Duration::from_secs(30)).await; + client.disconnect().await.unwrap(); + + let (result, _pingpong) = network_handle.await.unwrap(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), NetworkStatus::OutgoingDisconnect); } + ``` ### Sync example: @@ -284,7 +245,6 @@ Licensed under * Mozilla Public License, Version 2.0, [(MPL-2.0)](https://choosealicense.com/licenses/mpl-2.0/) ## Contribution - Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, shall be licensed under MPL-2.0, without any additional terms or conditions. diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index 68ede44..c6f120a 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -20,7 +20,7 @@ pub enum ReadError { IoError(#[from] std::io::Error), } -#[derive(Error, Clone, Debug)] +#[derive(Error, Clone, Debug, PartialEq, Eq)] pub enum DeserializeError { #[error("Malformed packet: {0}")] MalformedPacketWithInfo(String), diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index bdfb57b..6b691cb 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -55,6 +55,8 @@ where S: tokio::io::AsyncReadExt + tokio::io::AsyncWriteExt + Sized + Unpin + Send + 'static, { /// Initializes an MQTT connection with the provided configuration an stream + /// + /// It is recommended to use a buffered stream. [`tokio::io::BufStream`] could be used to easily buffer both read and write. pub async fn connect(&mut self, mut stream: S, handler: &mut H) -> Result<(), ConnectionError> { let conn_ack = stream.connect(&self.options).await?; self.last_network_action = Instant::now(); From 25369307891b245048d091cf8fef98706a90b35b Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:26:25 +0100 Subject: [PATCH 41/45] add entry in readme about fuzzing --- mqrstt/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mqrstt/README.md b/mqrstt/README.md index 8a7fb01..aa84c8e 100644 --- a/mqrstt/README.md +++ b/mqrstt/README.md @@ -22,6 +22,9 @@ For a sync approach the stream has to implement the [`std::io::Read`] and [`std: - TLS/TCP - Lean - Keep alive depends on actual communication +- This tokio implemention has been fuzzed using cargo-fuzz! + + ### To do - Even More testing From d96cb78ce9a9ad4f0c97d7005240a76ae2a22a35 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:29:59 +0100 Subject: [PATCH 42/45] remove AsyncWriteExt mentions for smol too --- examples/tcp/src/ping_pong_smol.rs | 2 +- mqrstt/README.md | 3 +-- mqrstt/src/packets/error.rs | 2 +- mqrstt/src/smol/network.rs | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/tcp/src/ping_pong_smol.rs b/examples/tcp/src/ping_pong_smol.rs index 2647daa..9af87bd 100644 --- a/examples/tcp/src/ping_pong_smol.rs +++ b/examples/tcp/src/ping_pong_smol.rs @@ -1,6 +1,6 @@ use mqrstt::{ packets::{self, Packet}, - AsyncEventHandler, ConnectOptions, MqttClient, NetworkBuilder, NetworkStatus, + AsyncEventHandler, MqttClient, NetworkBuilder, NetworkStatus, }; pub struct PingPong { pub client: MqttClient, diff --git a/mqrstt/README.md b/mqrstt/README.md index aa84c8e..a331db9 100644 --- a/mqrstt/README.md +++ b/mqrstt/README.md @@ -10,8 +10,7 @@ `MQRSTT` is an MQTTv5 client that provides sync and async (smol and tokio) implementation. Because this crate aims to be runtime agnostic the user is required to provide their own data stream. -For an async approach the stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWriteExt`] traits. -For a sync approach the stream has to implement the [`std::io::Read`] and [`std::io::Write`] traits. +The stream has to implement the smol or tokio [`AsyncReadExt`] and [`AsyncWrite`] traits. diff --git a/mqrstt/src/packets/error.rs b/mqrstt/src/packets/error.rs index c6f120a..68ede44 100644 --- a/mqrstt/src/packets/error.rs +++ b/mqrstt/src/packets/error.rs @@ -20,7 +20,7 @@ pub enum ReadError { IoError(#[from] std::io::Error), } -#[derive(Error, Clone, Debug, PartialEq, Eq)] +#[derive(Error, Clone, Debug)] pub enum DeserializeError { #[error("Malformed packet: {0}")] MalformedPacketWithInfo(String), diff --git a/mqrstt/src/smol/network.rs b/mqrstt/src/smol/network.rs index 0d96c29..3eed906 100644 --- a/mqrstt/src/smol/network.rs +++ b/mqrstt/src/smol/network.rs @@ -15,7 +15,7 @@ use crate::{AsyncEventHandler, StateHandler}; use super::stream::Stream; -/// [`Network`] reads and writes to the network based on tokios [`::smol::io::AsyncReadExt`] [`::smol::io::AsyncWriteExt`]. +/// [`Network`] reads and writes to the network based on tokios [`::smol::io::AsyncRead`] [`::smol::io::AsyncWrite`]. /// This way you can provide the `connect` function with a TLS and TCP stream of your choosing. /// The most import thing to remember is that you have to provide a new stream after the previous has failed. /// (i.e. you need to reconnect after any expected or unexpected disconnect). @@ -62,7 +62,7 @@ impl Network { impl Network where H: AsyncEventHandler, - S: smol::io::AsyncReadExt + smol::io::AsyncWriteExt + Sized + Unpin, + S: smol::io::AsyncRead + smol::io::AsyncWrite + Sized + Unpin, { /// Initializes an MQTT connection with the provided configuration an stream pub async fn connect(&mut self, stream: S, handler: &mut H) -> Result<(), ConnectionError> { From 25a59de19b0c0d8f1cff871617906d6431f46231 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 10:32:31 +0100 Subject: [PATCH 43/45] remove first_byte function --- mqrstt/src/packets/mod.rs | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 7fd4a0a..8af3522 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -83,36 +83,6 @@ impl Packet { } } - pub(crate) fn first_byte(&self) -> u8 { - match self { - Packet::Connect(_) => 0b0001_0000, - Packet::ConnAck(_) => 0b0010_0000, - Packet::Publish(p) => { - let mut first_byte = 0b0011_0000u8; - if p.dup { - first_byte |= 0b1000; - } - first_byte |= p.qos.into_u8() << 1; - if p.retain { - first_byte |= 0b0001; - } - first_byte - } - Packet::PubAck(_) => 0b0100_0000, - Packet::PubRec(_) => 0b0101_0000, - Packet::PubRel(_) => 0b0110_0010, - Packet::PubComp(_) => 0b0111_0000, - Packet::Subscribe(_) => 0b1000_0010, - Packet::SubAck(_) => 0b1001_0000, - Packet::Unsubscribe(_) => 0b1010_0010, - Packet::UnsubAck(_) => 0b1011_0000, - Packet::PingReq => 0b1100_0000, - Packet::PingResp => 0b1101_0000, - Packet::Disconnect(_) => 0b1110_0000, - Packet::Auth(_) => 0b1111_0000, - } - } - pub(crate) fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { match self { Packet::Connect(p) => { From 8d5cb02e6b6c86b07d85184deb5752152f00e93b Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:24:09 +0100 Subject: [PATCH 44/45] more test cases and fix in pub rel, comp and rec --- fuzz/fuzz_targets/fuzz_target_1.rs | 1 + mqrstt/Cargo.toml | 2 + .../src/packets/macros/properties_macros.rs | 6 +- mqrstt/src/packets/mod.rs | 96 +++++++++++-------- mqrstt/src/packets/pubcomp/mod.rs | 4 +- mqrstt/src/packets/pubrec/mod.rs | 5 +- mqrstt/src/packets/pubrel/mod.rs | 4 +- mqrstt/src/tests/test_packets.rs | 88 ++++++++++++++++- 8 files changed, 156 insertions(+), 50 deletions(-) diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs index 9f16727..84adcaa 100644 --- a/fuzz/fuzz_targets/fuzz_target_1.rs +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -7,6 +7,7 @@ async fn test(mut data: &[u8]) { let _ = mqrstt::packets::Packet::async_read(&mut data).await; } +#[cfg(target_os = "linux")] fuzz_target!(|data: &[u8]| { test(data); }); diff --git a/mqrstt/Cargo.toml b/mqrstt/Cargo.toml index 31d00c4..0b7d79a 100644 --- a/mqrstt/Cargo.toml +++ b/mqrstt/Cargo.toml @@ -52,6 +52,8 @@ tokio = { version = "1", features = [ smol = { version = "2", optional = true } [dev-dependencies] +pretty_assertions = "1.4.1" + tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } smol = { version = "2.0.0" } diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index b7b8438..759c4b2 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -399,17 +399,17 @@ macro_rules! properties_read_match_branch_body { $properties.authentication_data = Some(prop_body); }}; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestResponseInformation) => {{ - if $properties.authentication_data.is_some() { + if $properties.request_response_information.is_some() { return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( PropertyType::RequestResponseInformation, ))); } let (prop_body, read_bytes) = u8::async_read($stream).await?; $read_property_bytes += read_bytes; - $properties.request_problem_information = Some(prop_body); + $properties.request_response_information = Some(prop_body); }}; ($stream:ident, $properties:ident, $read_property_bytes:ident, PropertyType::RequestProblemInformation) => {{ - if $properties.authentication_data.is_some() { + if $properties.request_problem_information.is_some() { return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::DuplicateProperty( PropertyType::RequestProblemInformation, ))); diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 8af3522..5921c21 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -453,6 +453,7 @@ mod tests { use crate::tests::test_packets::*; #[rstest::rstest] + #[case::connect_case(connect_case())] #[case::ping_req_case(ping_req_case().1)] #[case::ping_resp_case(ping_resp_case().1)] #[case::connack_case(connack_case().1)] @@ -472,6 +473,10 @@ mod tests { #[case::suback(suback_case())] #[case::unsubscribe(unsubscribe_case())] #[case::unsuback(unsuback_case())] + #[case::pubcomp_case(pubcomp_case())] + #[case::pubrec_case(pubrec_case())] + #[case::pubrec_case(pubrel_case2())] + #[case::auth_case(auth_case())] fn test_write_read_write_read_cases(#[case] packet: Packet) { use crate::packets::WireLength; @@ -481,9 +486,11 @@ mod tests { let wire_len = packet.wire_len(); assert_eq!(wire_len, buffer.len()); - dbg!(wire_len); - let a: Vec<_> = buffer.iter().map(|f| *f as u16).collect(); - println!("{:?}", a); + + // dbg!(wire_len); + + // let a: Vec<_> = buffer.iter().map(|f| *f as u16).collect(); + // println!("{:?}", a); let res1 = Packet::read(&mut buffer).unwrap(); @@ -496,6 +503,50 @@ mod tests { assert_eq!(res1, res2); } + #[rstest::rstest] + #[case::connect_case(connect_case())] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] + #[case::pubcomp_case(pubcomp_case())] + #[case::pubrec_case(pubrec_case())] + #[case::pubrec_case(pubrel_case2())] + #[case::auth_case(auth_case())] + #[tokio::test] + async fn test_async_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + + let mut buffer = Vec::with_capacity(1000); + let res = packet.async_write(&mut buffer).await.unwrap(); + + let wire_len = packet.wire_len(); + + assert_eq!(res, buffer.len()); + assert_eq!(wire_len, buffer.len()); + + let mut buf = buffer.as_slice(); + + let res1 = Packet::async_read(&mut buf).await.unwrap(); + + pretty_assertions::assert_eq!(packet, res1); + } + #[rstest::rstest] #[case::disconnect(disconnect_case())] #[case::ping_req(ping_req_case())] @@ -548,45 +599,6 @@ mod tests { assert_eq!(out, input) } - #[rstest::rstest] - #[case::ping_req_case(ping_req_case().1)] - #[case::ping_resp_case(ping_resp_case().1)] - #[case::connack_case(connack_case().1)] - #[case::create_subscribe_packet(create_subscribe_packet(1))] - #[case::create_subscribe_packet(create_subscribe_packet(65335))] - #[case::create_puback_packet(create_puback_packet(1))] - #[case::create_puback_packet(create_puback_packet(65335))] - #[case::create_disconnect_packet(create_disconnect_packet())] - #[case::create_connack_packet(create_connack_packet(true))] - #[case::create_connack_packet(create_connack_packet(false))] - #[case::publish_packet_1(publish_packet_1())] - #[case::publish_packet_2(publish_packet_2())] - #[case::publish_packet_3(publish_packet_3())] - #[case::publish_packet_4(publish_packet_4())] - #[case::create_empty_publish_packet(create_empty_publish_packet())] - #[case::subscribe(subscribe_case())] - #[case::suback(suback_case())] - #[case::unsubscribe(unsubscribe_case())] - #[case::unsuback(unsuback_case())] - #[tokio::test] - async fn test_async_write_read_write_read_cases(#[case] packet: Packet) { - use crate::packets::WireLength; - - let mut buffer = Vec::with_capacity(1000); - let res = packet.async_write(&mut buffer).await.unwrap(); - - let wire_len = packet.wire_len(); - - assert_eq!(res, buffer.len()); - assert_eq!(wire_len, buffer.len()); - - let mut buf = buffer.as_slice(); - - let res1 = Packet::async_read(&mut buf).await.unwrap(); - - assert_eq!(packet, res1); - } - // #[rstest::rstest] // #[case(&[59, 1, 0, 59])] // #[case(&[16, 14, 0, 4, 77, 81, 84, 84, 5, 247, 247, 252, 1, 17, 247, 247, 247])] diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 459b036..dce9b43 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -7,6 +7,7 @@ pub use properties::PubCompProperties; use super::{ error::{DeserializeError, ReadError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, }; use bytes::BufMut; use tokio::io::AsyncReadExt; @@ -148,7 +149,8 @@ impl WireLength for PubComp { } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { - 2 + 1 + self.properties.wire_len() + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len } } } diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 92ca8be..075ee04 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -11,7 +11,7 @@ use tokio::io::AsyncReadExt; use super::{ error::DeserializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketRead, PacketWrite, WireLength}, - PacketAsyncRead, + PacketAsyncRead, VariableInteger, }; /// The [`PubRec`] (Publish Received) packet is part of the acknowledgment flow for a [`crate::packets::Publish`] with QoS 2. @@ -141,7 +141,8 @@ impl WireLength for PubRec { } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { - 2 + 1 + self.properties.wire_len() + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len } } } diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index 03d78d9..c714739 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -10,6 +10,7 @@ use tokio::io::AsyncReadExt; use super::{ error::{DeserializeError, ReadError}, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite, WireLength}, + VariableInteger, }; /// The [`PubRel`] (Publish Release) packet acknowledges the reception of a [`crate::packets::PubRec`] Packet. @@ -139,7 +140,8 @@ impl WireLength for PubRel { } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { 3 } else { - 2 + 1 + self.properties.wire_len() + let prop_wire_len = self.properties.wire_len(); + 2 + 1 + prop_wire_len.variable_integer_len() + prop_wire_len } } } diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index ffc8f18..f23a80a 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -119,6 +119,31 @@ pub fn pubrel_smallest_case() -> (&'static [u8], Packet) { (packet, Packet::PubRel(expected)) } +pub fn connect_case() -> Packet { + let connect = Connect { + protocol_version: ProtocolVersion::V5, + clean_start: true, + last_will: Some(LastWill::new(QoS::ExactlyOnce, true, "will/topic", b"will payload".to_vec())), + username: Some("ThisIsTheUsername".into()), + password: Some("ThisIsThePassword".into()), + keep_alive: 60, + connect_properties: ConnectProperties { + session_expiry_interval: Some(5), + receive_maximum: Some(10), + maximum_packet_size: Some(100), + topic_alias_maximum: Some(10), + user_properties: vec![("test".into(), "test".into()), ("test2".into(), "test2".into())], + authentication_method: Some("AuthenticationMethod".into()), + authentication_data: Some(b"AuthenticationData".to_vec()), + request_response_information: Some(0), + request_problem_information: Some(1), + }, + client_id: "ThisIsTheClientID".into(), + }; + + Packet::Connect(connect) +} + pub fn publish_packet_1() -> Packet { Packet::Publish(Publish { dup: false, @@ -324,12 +349,73 @@ pub fn unsubscribe_case() -> Packet { let expected = Unsubscribe { packet_identifier: 3, topics: vec!["test/topic".into()], - properties: UnsubscribeProperties { user_properties: vec![("written += 1;".into(), "value".into())] }, + properties: UnsubscribeProperties { + user_properties: vec![("written += 1;".into(), "value".into())], + }, }; Packet::Unsubscribe(expected) } +pub fn pubrec_case() -> Packet { + let expected = PubRec { + packet_identifier: 3, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties { + reason_string: Some("test".into()), + user_properties: vec![("test5asdf".into(), "test3".into()), ("test4".into(), "test2".into())], + }, + }; + + Packet::PubRec(expected) +} + +pub fn pubcomp_case() -> Packet { + let expected = PubComp { + packet_identifier: 3, + reason_code: PubCompReasonCode::PacketIdentifierNotFound, + properties: PubCompProperties { + reason_string: Some("test".into()), + user_properties: vec![ + ("test5asdf".into(), "test3".into()), + ("test⌚5asdf".into(), "test3".into()), + ("test5asdf".into(), "test3".into()), + ("test5asdf".into(), "test3".into()), + ("test4".into(), "test2".into()), + ], + }, + }; + + Packet::PubComp(expected) +} + +pub fn pubrel_case2() -> Packet { + let expected = PubRel { + packet_identifier: 3, + reason_code: PubRelReasonCode::Success, + properties: PubRelProperties { + reason_string: Some("test".into()), + user_properties: vec![("test5asdf".into(), "test3".repeat(10000).into()), ("test4".into(), "test2".into())], + }, + }; + + Packet::PubRel(expected) +} + +pub fn auth_case() -> Packet { + let expected = Auth { + reason_code: AuthReasonCode::ContinueAuthentication, + properties: AuthProperties { + authentication_method: Some("SomeRandomDataHere".into()), + authentication_data: Some(b"VeryRandomStuff".to_vec()), + reason_string: Some("⌚this_is_for_sure_a_test_⌚".into()), + user_properties: vec![("SureHopeThisWorks".into(), "😰".into())], + }, + }; + + Packet::Auth(expected) +} + #[rstest] #[case(create_subscribe_packet(1))] #[case(create_subscribe_packet(65335))] From 634e05b81f87fa767f7d0c3090bec2e3c0cdb4c1 Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:25:19 +0100 Subject: [PATCH 45/45] Disable fuzzing crate, gives errors in windows --- Cargo.toml | 2 +- fuzz/fuzz_targets/fuzz_target_1.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 10dc1aa..9b444dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,6 @@ members = [ "mqrstt", - "fuzz", + # "fuzz", "examples/tcp" ] \ No newline at end of file diff --git a/fuzz/fuzz_targets/fuzz_target_1.rs b/fuzz/fuzz_targets/fuzz_target_1.rs index 84adcaa..2dc7634 100644 --- a/fuzz/fuzz_targets/fuzz_target_1.rs +++ b/fuzz/fuzz_targets/fuzz_target_1.rs @@ -1,7 +1,9 @@ #![no_main] +#[cfg(target_os = "linux")] use libfuzzer_sys::fuzz_target; +#[cfg(target_os = "linux")] #[tokio::main(flavor = "current_thread")] async fn test(mut data: &[u8]) { let _ = mqrstt::packets::Packet::async_read(&mut data).await;