From 878744dd92f6613edb5e17ed0aad6ab8c2b528ac Mon Sep 17 00:00:00 2001 From: GunnarMorrigan <13799935+GunnarMorrigan@users.noreply.github.com> Date: Mon, 18 Dec 2023 07:58:21 +0100 Subject: [PATCH] A start to general async read trait But this might also be totally not worth it as it is out of scope. Also, it is complex --- src/lib.rs | 2 + src/mqtt_async_traits.rs | 79 +++++++++ src/packets/auth.rs | 341 +++++++++++++++++++++---------------- src/packets/error.rs | 5 +- src/packets/mod.rs | 30 +++- src/packets/mqtt_traits.rs | 4 +- 6 files changed, 309 insertions(+), 152 deletions(-) create mode 100644 src/mqtt_async_traits.rs diff --git a/src/lib.rs b/src/lib.rs index eb330ba..9fa2435 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -188,6 +188,8 @@ mod connect_options; mod mqtt_handler; mod util; +mod mqtt_async_traits; + #[cfg(feature = "smol")] pub mod smol; #[cfg(feature = "sync")] diff --git a/src/mqtt_async_traits.rs b/src/mqtt_async_traits.rs new file mode 100644 index 0000000..dc81f50 --- /dev/null +++ b/src/mqtt_async_traits.rs @@ -0,0 +1,79 @@ +use futures::Future; +use tokio::io::AsyncReadExt; + +pub trait AsyncMqttRead { + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> impl Future> + Send + Sync; + + fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> impl Future> + Send + Sync; + + fn take() + + // async fn read_exact_into_vec(&mut self, limit: u64) -> std::io::Result>; +} + +// #[cfg(feature = "a")] +// impl MqttAsyncRead for S where S: futures::AsyncReadExt + Unpin { +// fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> impl Future> { +// self.read(buf) +// } +// } + +#[cfg(feature = "tokio")] +impl AsyncMqttRead for S where S: tokio::io::AsyncReadExt + Unpin + Send + Sync { + fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> impl Future> + Send + Sync { + self.read(buf) + } + + fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> impl Future> + Send + Sync{ + async{ + self.read_exact(buf).await.map(|_| ()) + } + } + + // async fn read_exact_into_vec(&mut self, limit: u64) -> std::io::Result> { + // let mut buffer = Vec::with_capacity(limit as usize); + + // self.take(limit).read_to_end(&mut buffer).await?; + + // Ok(buffer) + // } +} + + +pub trait MqttAsyncWrite { + async fn flush(&mut self) -> Result<(), std::io::Error>; + + async fn write<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result; + + async fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result<()>; +} + +#[cfg(feature = "a")] +impl MqttAsyncWrite for S where S: futures::AsyncWriteExt + Unpin{ + async fn flush(&mut self) -> Result<(), std::io::Error> { + self.flush().await + } + + async fn write<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result { + self.write(buf).await + } + + async fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result<()> { + self.write_all(buf).await + } +} + +#[cfg(feature = "tokio")] +impl MqttAsyncWrite for S where S: tokio::io::AsyncWriteExt + Unpin{ + async fn flush(&mut self) -> Result<(), std::io::Error> { + self.flush().await + } + + async fn write<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result { + self.write(buf).await + } + + async fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> std::io::Result<()> { + self.write_all(buf).await + } +} \ No newline at end of file diff --git a/src/packets/auth.rs b/src/packets/auth.rs index cc80935..81e6989 100644 --- a/src/packets/auth.rs +++ b/src/packets/auth.rs @@ -1,149 +1,192 @@ -use bytes::Bytes; - -use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, - read_variable_integer, - reason_codes::AuthReasonCode, - variable_integer_len, write_variable_integer, PacketType, PropertyType, -}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Auth { - pub reason_code: AuthReasonCode, - pub properties: AuthProperties, -} - -impl VariableHeaderRead for Auth { - fn read(_: u8, _: usize, mut buf: Bytes) -> Result { - let reason_code = AuthReasonCode::read(&mut buf)?; - let properties = AuthProperties::read(&mut buf)?; - - Ok(Self { reason_code, properties }) - } -} - -impl VariableHeaderWrite for Auth { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - self.reason_code.write(buf)?; - self.properties.write(buf)?; - Ok(()) - } -} - -impl WireLength for Auth { - fn wire_len(&self) -> usize { - 1 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len() - } -} - -#[derive(Debug, Default, PartialEq, Eq, Clone)] -pub struct AuthProperties { - /// 3.15.2.2.2 Authentication Method - /// 21 (0x15) Byte, Identifier of the Authentication Method. - pub authentication_method: Option, - - /// 3.15.2.2.3 Authentication Data - /// 22 (0x16) Byte, Identifier of the Authentication Data - pub authentication_data: Bytes, - - /// 3.15.2.2.4 Reason String - /// 31 (0x1F) Byte, Identifier of the Reason String - pub reason_string: Option, - - /// 3.15.2.2.5 User Property - /// 38 (0x26) Byte, Identifier of the User Property. - pub user_properties: Vec<(String, String)>, -} - -impl MqttRead for AuthProperties { - fn read(buf: &mut Bytes) -> Result { - let (len, _) = read_variable_integer(buf)?; - - let mut properties = AuthProperties::default(); - - if len == 0 { - return Ok(properties); - } else if buf.len() < len { - return Err(DeserializeError::MalformedPacket); - } - - let mut property_data = buf.split_to(len); - - loop { - match PropertyType::read(&mut property_data)? { - PropertyType::ReasonString => { - if properties.reason_string.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); - } - properties.reason_string = Some(String::read(&mut property_data)?); - } - PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), - PropertyType::AuthenticationMethod => { - if properties.authentication_method.is_some() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); - } - properties.authentication_method = Some(String::read(&mut property_data)?); - } - PropertyType::AuthenticationData => { - if properties.authentication_data.is_empty() { - return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); - } - properties.authentication_data = Bytes::read(&mut property_data)?; - } - e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), - } - - if property_data.is_empty() { - break; - } - } - - Ok(properties) - } -} - -impl MqttWrite for AuthProperties { - fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { - write_variable_integer(buf, self.wire_len())?; - - if let Some(authentication_method) = &self.authentication_method { - PropertyType::AuthenticationMethod.write(buf)?; - authentication_method.write(buf)?; - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - PropertyType::AuthenticationData.write(buf)?; - self.authentication_data.write(buf)?; - } - if let Some(reason_string) = &self.reason_string { - PropertyType::ReasonString.write(buf)?; - reason_string.write(buf)?; - } - for (key, value) in &self.user_properties { - PropertyType::UserProperty.write(buf)?; - key.write(buf)?; - value.write(buf)?; - } - - Ok(()) - } -} - -impl WireLength for AuthProperties { - fn wire_len(&self) -> usize { - let mut len = 0; - if let Some(authentication_method) = &self.authentication_method { - len += authentication_method.wire_len(); - } - if !self.authentication_data.is_empty() && self.authentication_method.is_some() { - len += self.authentication_data.wire_len(); - } - if let Some(reason_string) = &self.reason_string { - len += reason_string.wire_len(); - } - for (key, value) in &self.user_properties { - len += key.wire_len() + value.wire_len(); - } - len - } -} +use bytes::Bytes; + +use super::{ + error::DeserializeError, + mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}, + read_variable_integer, + reason_codes::AuthReasonCode, + variable_integer_len, write_variable_integer, PacketType, PropertyType, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Auth { + pub reason_code: AuthReasonCode, + pub properties: AuthProperties, +} + +impl VariableHeaderRead for Auth { + fn read(_: u8, _: usize, mut buf: Bytes) -> Result { + let reason_code = AuthReasonCode::read(&mut buf)?; + let properties = AuthProperties::read(&mut buf)?; + + Ok(Self { reason_code, properties }) + } +} + +impl VariableHeaderWrite for Auth { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + self.reason_code.write(buf)?; + self.properties.write(buf)?; + Ok(()) + } +} + +impl WireLength for Auth { + fn wire_len(&self) -> usize { + 1 + variable_integer_len(self.properties.wire_len()) + self.properties.wire_len() + } +} + +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct AuthProperties { + /// 3.15.2.2.2 Authentication Method + /// 21 (0x15) Byte, Identifier of the Authentication Method. + pub authentication_method: Option, + + /// 3.15.2.2.3 Authentication Data + /// 22 (0x16) Byte, Identifier of the Authentication Data + pub authentication_data: Bytes, + + /// 3.15.2.2.4 Reason String + /// 31 (0x1F) Byte, Identifier of the Reason String + pub reason_string: Option, + + /// 3.15.2.2.5 User Property + /// 38 (0x26) Byte, Identifier of the User Property. + pub user_properties: Vec<(String, String)>, +} + +impl MqttRead for AuthProperties{ + async fn read(stream: &mut S) -> Result { + let (len, _) = read_variable_integer(stream).await?; + + let mut properties = AuthProperties::default(); + + if len == 0 { + return Ok(properties); + } + +// loop { +// match PropertyType::read(&mut property_data)? { +// PropertyType::ReasonString => { +// if properties.reason_string.is_some() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); +// } +// properties.reason_string = Some(String::read(&mut property_data)?); +// } +// PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), +// PropertyType::AuthenticationMethod => { +// if properties.authentication_method.is_some() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); +// } +// properties.authentication_method = Some(String::read(&mut property_data)?); +// } +// PropertyType::AuthenticationData => { +// if properties.authentication_data.is_empty() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); +// } +// properties.authentication_data = Bytes::read(&mut property_data)?; +// } +// e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), +// } + +// if property_data.is_empty() { +// break; +// } +// } + +// Ok(properties) + } +} + +// impl MqttRead for AuthProperties { +// fn read(buf: &mut Bytes) -> Result { +// let (len, _) = read_variable_integer(buf)?; + +// let mut properties = AuthProperties::default(); + +// if len == 0 { +// return Ok(properties); +// } else if buf.len() < len { +// return Err(DeserializeError::MalformedPacket); +// } + +// let mut property_data = buf.split_to(len); + +// loop { +// match PropertyType::read(&mut property_data)? { +// PropertyType::ReasonString => { +// if properties.reason_string.is_some() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::SessionExpiryInterval)); +// } +// properties.reason_string = Some(String::read(&mut property_data)?); +// } +// PropertyType::UserProperty => properties.user_properties.push((String::read(&mut property_data)?, String::read(&mut property_data)?)), +// PropertyType::AuthenticationMethod => { +// if properties.authentication_method.is_some() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationMethod)); +// } +// properties.authentication_method = Some(String::read(&mut property_data)?); +// } +// PropertyType::AuthenticationData => { +// if properties.authentication_data.is_empty() { +// return Err(DeserializeError::DuplicateProperty(PropertyType::AuthenticationData)); +// } +// properties.authentication_data = Bytes::read(&mut property_data)?; +// } +// e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::Auth)), +// } + +// if property_data.is_empty() { +// break; +// } +// } + +// Ok(properties) +// } +// } + +impl MqttWrite for AuthProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + write_variable_integer(buf, self.wire_len())?; + + if let Some(authentication_method) = &self.authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if !self.authentication_data.is_empty() && self.authentication_method.is_some() { + PropertyType::AuthenticationData.write(buf)?; + self.authentication_data.write(buf)?; + } + if let Some(reason_string) = &self.reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, value) in &self.user_properties { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + value.write(buf)?; + } + + Ok(()) + } +} + +impl WireLength for AuthProperties { + fn wire_len(&self) -> usize { + let mut len = 0; + if let Some(authentication_method) = &self.authentication_method { + len += authentication_method.wire_len(); + } + if !self.authentication_data.is_empty() && self.authentication_method.is_some() { + len += self.authentication_data.wire_len(); + } + if let Some(reason_string) = &self.reason_string { + len += reason_string.wire_len(); + } + for (key, value) in &self.user_properties { + len += key.wire_len() + value.wire_len(); + } + len + } +} diff --git a/src/packets/error.rs b/src/packets/error.rs index c54689c..db47baa 100644 --- a/src/packets/error.rs +++ b/src/packets/error.rs @@ -4,7 +4,7 @@ use thiserror::Error; use super::{PacketType, PropertyType}; -#[derive(Error, Clone, Debug)] +#[derive(Error, Debug)] pub enum DeserializeError { #[error("Malformed packet: {0}")] MalformedPacketWithInfo(String), @@ -44,6 +44,9 @@ pub enum DeserializeError { #[error("Encountered an error when reading in a UTF-8 string. {0}")] Utf8Error(FromUtf8Error), + + #[error("IO error")] + IoError(#[from] std::io::Error) } impl From for DeserializeError { diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 5f160e4..df4d144 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -34,6 +34,8 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use core::slice::Iter; use std::fmt::Display; +use crate::mqtt_async_traits::AsyncMqttRead; + use self::error::{DeserializeError, ReadBytes, SerializeError}; use self::mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, VariableHeaderWrite, WireLength}; @@ -262,7 +264,33 @@ pub fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), Re Err(ReadBytes::Err(DeserializeError::MalformedPacket)) } -pub fn read_variable_integer(buf: &mut Bytes) -> Result<(usize, usize), DeserializeError> { +pub async fn read_variable_integer(stream: &mut S) -> Result<(usize, usize), DeserializeError> { + let mut integer = 0; + let mut length = 0; + + let mut byte: u8 = 0; + + for i in 0..4 { + // if buf.is_empty() { + // return Err(DeserializeError::MalformedPacket); + // } + + stream.read_exact(std::slice::from_mut(&mut byte)).await?; + + length += 1; + + + integer += (byte as usize & 0x7f) << (7 * i); + + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); + } + } + Err(DeserializeError::MalformedPacket) +} + + +pub fn read_variable_integer_old(buf: &mut Bytes) -> Result<(usize, usize), DeserializeError> { let mut integer = 0; let mut length = 0; diff --git a/src/packets/mqtt_traits.rs b/src/packets/mqtt_traits.rs index cd20f0f..72a0b53 100644 --- a/src/packets/mqtt_traits.rs +++ b/src/packets/mqtt_traits.rs @@ -1,5 +1,7 @@ use bytes::{Bytes, BytesMut}; +use crate::mqtt_async_traits::AsyncMqttRead; + use super::error::{DeserializeError, SerializeError}; pub trait VariableHeaderRead: Sized { @@ -15,7 +17,7 @@ pub trait WireLength { } pub trait MqttRead: Sized { - fn read(buf: &mut Bytes) -> Result; + async fn read(stream: &mut S) -> Result; } pub trait MqttWrite: Sized {