diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 4ce4a31..67e83c0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -34,7 +34,7 @@ jobs: # run clippy to verify we have no warnings - run: cargo fetch - name: cargo clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy -p mqrstt test: name: Test diff --git a/README.md b/README.md index 58dff00..7e112dc 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ pub struct PingPong { } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -132,7 +132,7 @@ pub struct PingPong { } impl AsyncEventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { @@ -212,7 +212,7 @@ pub struct PingPong { impl EventHandler for PingPong { // Handlers only get INCOMING packets. This can change later. - fn handle(&mut self, event: packets::Packet) -> () { + fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { diff --git a/mqrstt/src/error.rs b/mqrstt/src/error.rs index 7651ceb..b439912 100644 --- a/mqrstt/src/error.rs +++ b/mqrstt/src/error.rs @@ -69,8 +69,8 @@ pub enum HandlerError { #[error("The incoming channel between network and handler is closed")] IncomingNetworkChannelClosed, - #[error("The outgoing channel between handler and network is closed: {0}")] - OutgoingNetworkChannelClosed(#[from] SendError), + #[error("The outgoing channel between handler and network is closed")] + OutgoingNetworkChannelClosed, #[error("Channel between client and handler closed")] ClientChannelClosed, @@ -88,6 +88,12 @@ pub enum HandlerError { UnexpectedPacket(PacketType), } +impl From> for HandlerError { + fn from(_: SendError) -> Self { + HandlerError::OutgoingNetworkChannelClosed + } +} + /// Errors producable by the [`crate::MqttClient`] #[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] pub enum ClientError { diff --git a/mqrstt/src/event_handlers.rs b/mqrstt/src/event_handlers.rs index 555d04f..b69fca6 100644 --- a/mqrstt/src/event_handlers.rs +++ b/mqrstt/src/event_handlers.rs @@ -13,9 +13,7 @@ pub trait AsyncEventHandler { /// This is a simple no operation handler. impl AsyncEventHandler for () { - fn handle(&mut self, _: Packet) -> impl Future + Send + Sync { - async {} - } + async fn handle(&mut self, _: Packet) {} } pub trait EventHandler { @@ -60,7 +58,7 @@ pub mod example_handlers { } impl AsyncEventHandler for PingResp { - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { use Packet::*; if event == PingResp { self.ping_resp_received += 1; @@ -91,7 +89,7 @@ pub mod example_handlers { } impl AsyncEventHandler for PingPong { - async fn handle(&mut self, event: packets::Packet) -> () { + async fn handle(&mut self, event: packets::Packet) { match event { Packet::Publish(p) => { if let Ok(payload) = String::from_utf8(p.payload.to_vec()) { diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index c131ec6..e4d0477 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -123,7 +123,7 @@ pub mod smol; /// Contains the reader and writer parts for the tokio runtime. /// /// Module [`crate::tokio`] contains both a synchronized and concurrent approach to call the users `Handler`. -#[cfg(any(feature = "tokio"))] +#[cfg(feature = "tokio")] pub mod tokio; /// Error types that the user can see during operation of the client. @@ -342,7 +342,7 @@ mod smol_lib_test { }); } - #[cfg(all(target_family = "windows"))] + #[cfg(target_family = "windows")] #[test] fn test_close_write_tcp_stream_smol() { use crate::error::ConnectionError; @@ -409,7 +409,8 @@ mod tokio_lib_test { network.connect(stream, &mut pingresp).await.unwrap(); let network_handle = tokio::task::spawn(async move { - network.run(&mut pingresp).await; + let result = network.run(&mut pingresp).await; + // check result and or restart the connection pingresp }); diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index 74a2107..e16fdda 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -1,5 +1,4 @@ mod properties; -use std::future::Future; pub use properties::AuthProperties; mod reason_code; @@ -48,12 +47,10 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Auth where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl Future> { - async move { - let reason_code_writen = self.reason_code.async_write(stream).await?; - let properties_writen = self.properties.async_write(stream).await?; - Ok(reason_code_writen + properties_writen) - } + async fn async_write(&self, stream: &mut S) -> Result { + let reason_code_writen = self.reason_code.async_write(stream).await?; + let properties_writen = self.properties.async_write(stream).await?; + Ok(reason_code_writen + properties_writen) } } diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 2ff8650..94da71f 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -53,21 +53,19 @@ impl PacketAsyncRead for ConnAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, _: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; - let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; - let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; - - Ok(( - Self { - connack_flags, - reason_code, - connack_properties, - }, - read_bytes + reason_code_read_bytes + connack_properties_read_bytes, - )) - } + async fn async_read(_: u8, _: usize, stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let (connack_flags, read_bytes) = ConnAckFlags::async_read(stream).await?; + let (reason_code, reason_code_read_bytes) = ConnAckReasonCode::async_read(stream).await?; + let (connack_properties, connack_properties_read_bytes) = ConnAckProperties::async_read(stream).await?; + + Ok(( + Self { + connack_flags, + reason_code, + connack_properties, + }, + read_bytes + reason_code_read_bytes + connack_properties_read_bytes, + )) } } @@ -85,15 +83,13 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for ConnAck where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use crate::packets::mqtt_trait::MqttAsyncWrite; - let connack_flags_writen = self.connack_flags.async_write(stream).await?; - let reason_code_writen = self.reason_code.async_write(stream).await?; - let connack_properties_writen = self.connack_properties.async_write(stream).await?; - - Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) - } + async fn async_write(&self, stream: &mut S) -> Result { + use crate::packets::mqtt_trait::MqttAsyncWrite; + let connack_flags_writen = self.connack_flags.async_write(stream).await?; + let reason_code_writen = self.reason_code.async_write(stream).await?; + let connack_properties_writen = self.connack_properties.async_write(stream).await?; + + Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) } } @@ -114,16 +110,14 @@ impl MqttAsyncRead for ConnAckFlags where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; - Ok(( - Self { - session_present: (byte & 0b00000001) == 0b00000001, - }, - 1, - )) - } + async fn async_read(stream: &mut S) -> Result<(Self, usize), super::error::ReadError> { + let byte = stream.read_u8().await?; + Ok(( + Self { + session_present: (byte & 0b00000001) == 0b00000001, + }, + 1, + )) } } @@ -154,14 +148,12 @@ impl crate::packets::mqtt_trait::MqttAsyncWrite for ConnAckFlags where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use tokio::io::AsyncWriteExt; - let byte = self.session_present as u8; + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.session_present as u8; - stream.write_u8(byte).await?; - Ok(1) - } + stream.write_u8(byte).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/connect/connect_flags.rs b/mqrstt/src/packets/connect/connect_flags.rs index c2a576e..129a132 100644 --- a/mqrstt/src/packets/connect/connect_flags.rs +++ b/mqrstt/src/packets/connect/connect_flags.rs @@ -77,11 +77,9 @@ impl MqttAsyncRead for ConnectFlags where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; - Ok((ConnectFlags::from_u8(byte)?, 1)) - } + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; + Ok((ConnectFlags::from_u8(byte)?, 1)) } } @@ -96,13 +94,11 @@ impl MqttAsyncWrite for ConnectFlags where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - use tokio::io::AsyncWriteExt; - let byte = self.into_u8()?; - stream.write_u8(byte).await?; + async fn async_write(&self, stream: &mut S) -> Result { + use tokio::io::AsyncWriteExt; + let byte = self.into_u8()?; + stream.write_u8(byte).await?; - Ok(1) - } + Ok(1) } } diff --git a/mqrstt/src/packets/connect/last_will.rs b/mqrstt/src/packets/connect/last_will.rs index b93e88e..bc076bc 100644 --- a/mqrstt/src/packets/connect/last_will.rs +++ b/mqrstt/src/packets/connect/last_will.rs @@ -82,14 +82,12 @@ impl MqttAsyncWrite for LastWill where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let properties_written = self.last_will_properties.async_write(stream).await?; - let topic_written = self.topic.async_write(stream).await?; - let payload_written = self.payload.async_write(stream).await?; + async fn async_write(&self, stream: &mut S) -> Result { + let properties_written = self.last_will_properties.async_write(stream).await?; + let topic_written = self.topic.async_write(stream).await?; + let payload_written = self.payload.async_write(stream).await?; - Ok(properties_written + topic_written + payload_written) - } + Ok(properties_written + topic_written + payload_written) } } diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index d845d9f..b71e149 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -34,40 +34,14 @@ macro_rules! define_properties { } impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let mut bytes_writen = 0; - $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; - $( - $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); - )* - - Ok(bytes_writen) - } - - // let (len, length_variable_integer) = ::read_async_variable_integer(stream).await?; - // if len == 0 { - // return Ok((Self::default(), length_variable_integer)); - // } - - // let mut properties = $name::default(); - - // let mut read_property_bytes = 0; - // loop { - // let (prop, read_bytes) = crate::packets::PropertyType::async_read(stream).await?; - // read_property_bytes += read_bytes; - // match prop { - // $( - // $crate::packets::macros::properties_read_match_branch_name!($prop_variant) => $crate::packets::macros::properties_read_match_branch_body!(stream, properties, read_property_bytes, PropertyType::$prop_variant), - // )* - // e => return Err($crate::packets::error::ReadError::DeserializeError(DeserializeError::UnexpectedProperty(e, PacketType::PubRel))), - // } - // if read_property_bytes == len { - // break; - // } - // } + async fn async_write(&self, stream: &mut S) -> Result { + let mut bytes_writen = 0; + $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; + $( + $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); + )* - // Ok((properties, length_variable_integer + read_property_bytes)) + Ok(bytes_writen) } } diff --git a/mqrstt/src/packets/macros/reason_code_macros.rs b/mqrstt/src/packets/macros/reason_code_macros.rs index 8e1fa29..68ec2e4 100644 --- a/mqrstt/src/packets/macros/reason_code_macros.rs +++ b/mqrstt/src/packets/macros/reason_code_macros.rs @@ -14,7 +14,7 @@ macro_rules! reason_code { } -> ()) } - pub(crate) fn to_u8(&self) -> u8 { + pub(crate) fn to_u8(self) -> u8 { $crate::packets::macros::reason_code_match_write!(@ $name, self, { $($code,)* } -> ()) diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 73a477f..2eb135a 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -347,7 +347,7 @@ impl Packet { #[cfg(feature = "logs")] tracing::trace!("Read packet header: {:?}", header); - Ok(Packet::async_read_packet(header, stream).await?) + Packet::async_read_packet(header, stream).await } pub fn read(buffer: &mut BytesMut) -> Result> { diff --git a/mqrstt/src/packets/mqtt_trait/mod.rs b/mqrstt/src/packets/mqtt_trait/mod.rs index 50ce760..f666753 100644 --- a/mqrstt/src/packets/mqtt_trait/mod.rs +++ b/mqrstt/src/packets/mqtt_trait/mod.rs @@ -43,7 +43,7 @@ pub trait MqttWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), crate::packets::error::SerializeError>; } -impl<'a, T> MqttWrite for &'a T +impl MqttWrite for &T where T: MqttWrite, { diff --git a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs index 8ff217a..b0612cb 100644 --- a/mqrstt/src/packets/mqtt_trait/primitive_impl.rs +++ b/mqrstt/src/packets/mqtt_trait/primitive_impl.rs @@ -43,13 +43,11 @@ impl MqttAsyncWrite for Box where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -73,13 +71,11 @@ impl MqttAsyncWrite for &str where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -131,13 +127,11 @@ impl MqttAsyncWrite for String where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_bytes()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_bytes()).await?; + Ok(2 + self.len()) } } @@ -189,13 +183,11 @@ impl MqttAsyncWrite for Bytes where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self.as_ref()).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self.as_ref()).await?; + Ok(2 + self.len()) } } @@ -234,13 +226,11 @@ impl MqttAsyncWrite for Vec where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let size = (self.len() as u16).to_be_bytes(); - stream.write_all(&size).await?; - stream.write_all(self).await?; - Ok(2 + self.len()) - } + async fn async_write(&self, stream: &mut S) -> Result { + let size = (self.len() as u16).to_be_bytes(); + stream.write_all(&size).await?; + stream.write_all(self).await?; + Ok(2 + self.len()) } } impl WireLength for Vec { @@ -304,15 +294,13 @@ impl MqttAsyncWrite for bool where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - if *self { - stream.write_all(&[1]).await?; - } else { - stream.write_all(&[0]).await?; - } - Ok(1) + async fn async_write(&self, stream: &mut S) -> Result { + if *self { + stream.write_all(&[1]).await?; + } else { + stream.write_all(&[0]).await?; } + Ok(1) } } impl MqttRead for u8 { @@ -336,11 +324,9 @@ impl MqttAsyncWrite for u8 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(1) } } @@ -372,11 +358,9 @@ impl MqttAsyncWrite for u16 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(2) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(2) } } @@ -407,10 +391,8 @@ impl MqttAsyncWrite for u32 where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - stream.write_all(self.to_be_bytes().as_slice()).await?; - Ok(4) - } + async fn async_write(&self, stream: &mut S) -> Result { + stream.write_all(self.to_be_bytes().as_slice()).await?; + Ok(4) } } diff --git a/mqrstt/src/packets/primitive/property_type.rs b/mqrstt/src/packets/primitive/property_type.rs index cce56e1..a1ecf70 100644 --- a/mqrstt/src/packets/primitive/property_type.rs +++ b/mqrstt/src/packets/primitive/property_type.rs @@ -148,11 +148,9 @@ impl MqttAsyncWrite for PropertyType where S: tokio::io::AsyncWrite + std::marker::Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let buf: [u8; 1] = [u8::from(self)]; - stream.write_all(&buf).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [u8::from(self)]; + stream.write_all(&buf).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/primitive/qos.rs b/mqrstt/src/packets/primitive/qos.rs index 572a6a5..06be58e 100644 --- a/mqrstt/src/packets/primitive/qos.rs +++ b/mqrstt/src/packets/primitive/qos.rs @@ -76,11 +76,9 @@ impl MqttAsyncWrite for QoS where S: tokio::io::AsyncWrite + std::marker::Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { - async move { - let buf: [u8; 1] = [self.into_u8()]; - stream.write_all(&buf).await?; - Ok(1) - } + async fn async_write(&self, stream: &mut S) -> Result { + let buf: [u8; 1] = [self.into_u8()]; + stream.write_all(&buf).await?; + Ok(1) } } diff --git a/mqrstt/src/packets/primitive/variable_integer.rs b/mqrstt/src/packets/primitive/variable_integer.rs index 9ee609f..1841ead 100644 --- a/mqrstt/src/packets/primitive/variable_integer.rs +++ b/mqrstt/src/packets/primitive/variable_integer.rs @@ -66,7 +66,7 @@ impl VariableInteger for usize { fn write_variable_integer(&self, buf: &mut BytesMut) -> Result { if *self > 268_435_455 { - return Err(SerializeError::VariableIntegerOverflow(*self as usize)); + return Err(SerializeError::VariableIntegerOverflow(*self)); } let mut write = *self; @@ -82,7 +82,7 @@ impl VariableInteger for usize { return Ok(i + 1); } } - Err(SerializeError::VariableIntegerOverflow(*self as usize)) + Err(SerializeError::VariableIntegerOverflow(*self)) } fn read_variable_integer(buf: &mut Bytes) -> Result<(Self, usize), DeserializeError> { @@ -105,51 +105,47 @@ impl VariableInteger for usize { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { - async move { - let mut integer = 0; - let mut length = 0; + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; - for i in 0..4 { - let byte = stream.read_u8().await?; - length += 1; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; - integer += (byte as usize & 0x7f) << (7 * i); + integer += (byte as usize & 0x7f) << (7 * i); - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); } - Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } - fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { - async move { - let mut buf = [0u8; 4]; + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; - if *self > 268_435_455 { - return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); - } + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self))); + } - let mut write = *self; - let mut length = 1; - - for i in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf[i] = byte; - if write == 0 { - length = i + 1; - break; - } + let mut write = *self; + let mut length = 1; + + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; } - stream.write_all(&buf[0..length]).await?; - Ok(length) } + stream.write_all(&buf[0..length]).await?; + Ok(length) } } @@ -207,50 +203,46 @@ impl VariableInteger for u32 { Err(DeserializeError::MalformedPacket) } - fn read_async_variable_integer(stream: &mut S) -> impl Future> { - async move { - let mut integer = 0; - let mut length = 0; + async fn read_async_variable_integer(stream: &mut S) -> Result<(Self, usize), ReadError> { + let mut integer = 0; + let mut length = 0; - for i in 0..4 { - let byte = stream.read_u8().await?; - length += 1; + for i in 0..4 { + let byte = stream.read_u8().await?; + length += 1; - integer += (byte as u32 & 0x7f) << (7 * i); + integer += (byte as u32 & 0x7f) << (7 * i); - if (byte & 0b1000_0000) == 0 { - return Ok((integer, length)); - } + if (byte & 0b1000_0000) == 0 { + return Ok((integer, length)); } - Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } + Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)) } - fn write_async_variable_integer(&self, stream: &mut S) -> impl Future> { - async move { - let mut buf = [0u8; 4]; + async fn write_async_variable_integer(&self, stream: &mut S) -> Result { + let mut buf = [0u8; 4]; - if *self > 268_435_455 { - return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); - } + if *self > 268_435_455 { + return Err(WriteError::SerializeError(SerializeError::VariableIntegerOverflow(*self as usize))); + } + + let mut write = *self; + let mut length = 1; - let mut write = *self; - let mut length = 1; - - for i in 0..4 { - let mut byte = (write % 128) as u8; - write /= 128; - if write > 0 { - byte |= 128; - } - buf[i] = byte; - if write == 0 { - length = i + 1; - break; - } + for (i, item) in buf.iter_mut().enumerate() { + let mut byte = (write % 128) as u8; + write /= 128; + if write > 0 { + byte |= 128; + } + *item = byte; + if write == 0 { + length = i + 1; + break; } - stream.write_all(&buf[0..length]).await?; - Ok(length) } + stream.write_all(&buf[0..length]).await?; + Ok(length) } } diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 67e3450..3dd5ed7 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -64,45 +64,43 @@ impl PacketAsyncRead for PubComp where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let packet_identifier = stream.read_u16().await?; - if remaining_length == 2 { - return Ok(( - Self { - packet_identifier, - reason_code: PubCompReasonCode::Success, - properties: PubCompProperties::default(), - }, - 2, - )); - } - // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes - else if remaining_length < 4 { - return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); - } - - let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; - let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; - - let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; - - if total_read_bytes != remaining_length { - return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { - read: total_read_bytes, - remaining_length: remaining_length, - })); - } - - Ok(( + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let packet_identifier = stream.read_u16().await?; + if remaining_length == 2 { + return Ok(( Self { packet_identifier, - reason_code, - properties, + reason_code: PubCompReasonCode::Success, + properties: PubCompProperties::default(), }, - total_read_bytes, - )) + 2, + )); + } + // Requires u16, u8 and at least 1 byte of variable integer prop length so at least 4 bytes + else if remaining_length < 4 { + return Err(ReadError::DeserializeError(DeserializeError::InsufficientData(std::any::type_name::(), 0, 4))); + } + + let (reason_code, reason_code_read_bytes) = PubCompReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubCompProperties::async_read(stream).await?; + + let total_read_bytes = 2 + reason_code_read_bytes + properties_read_bytes; + + if total_read_bytes != remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::RemainingDataError { + read: total_read_bytes, + remaining_length, + })); } + + Ok(( + Self { + packet_identifier, + reason_code, + properties, + }, + total_read_bytes, + )) } } @@ -126,22 +124,20 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for PubComp where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - async move { - let mut total_writen_bytes = 2; - self.packet_identifier.async_write(stream).await?; - - if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); - } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; - } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; - } - Ok(total_writen_bytes) + let mut total_writen_bytes = 2; + self.packet_identifier.async_write(stream).await?; + + if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + return Ok(total_writen_bytes); + } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { + total_writen_bytes += self.reason_code.async_write(stream).await?; + } else { + total_writen_bytes += self.reason_code.async_write(stream).await?; + total_writen_bytes += self.properties.async_write(stream).await?; } + Ok(total_writen_bytes) } } diff --git a/mqrstt/src/packets/publish/mod.rs b/mqrstt/src/packets/publish/mod.rs index 71dc8a3..004dafc 100644 --- a/mqrstt/src/packets/publish/mod.rs +++ b/mqrstt/src/packets/publish/mod.rs @@ -96,46 +96,44 @@ impl PacketAsyncRead for Publish where S: tokio::io::AsyncRead + Unpin, { - fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let dup = flags & 0b1000 != 0; - let qos = QoS::from_u8((flags & 0b110) >> 1)?; - let retain = flags & 0b1 != 0; - - let (topic, topic_read_bytes) = Box::::async_read(stream).await?; - total_read_bytes += topic_read_bytes; - let packet_identifier = if qos == QoS::AtMostOnce { - None - } else { - total_read_bytes += 2; - Some(stream.read_u16().await?) - }; - let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; - total_read_bytes += properties_read_bytes; - - if total_read_bytes > remaining_length { - return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); - } - let payload_len = remaining_length - total_read_bytes; - let mut payload = vec![0u8; payload_len]; - let payload_read_bytes = stream.read_exact(&mut payload).await?; - - assert_eq!(payload_read_bytes, payload_len); - - Ok(( - Self { - dup, - qos, - retain, - topic, - packet_identifier, - publish_properties, - payload, - }, - total_read_bytes + payload_read_bytes, - )) + async fn async_read(flags: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let dup = flags & 0b1000 != 0; + let qos = QoS::from_u8((flags & 0b110) >> 1)?; + let retain = flags & 0b1 != 0; + + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_bytes; + let packet_identifier = if qos == QoS::AtMostOnce { + None + } else { + total_read_bytes += 2; + Some(stream.read_u16().await?) + }; + let (publish_properties, properties_read_bytes) = PublishProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; + + if total_read_bytes > remaining_length { + return Err(ReadError::DeserializeError(DeserializeError::MalformedPacket)); } + let payload_len = remaining_length - total_read_bytes; + let mut payload = vec![0u8; payload_len]; + let payload_read_bytes = stream.read_exact(&mut payload).await?; + + assert_eq!(payload_read_bytes, payload_len); + + Ok(( + Self { + dup, + qos, + retain, + topic, + packet_identifier, + publish_properties, + payload, + }, + total_read_bytes + payload_read_bytes, + )) } } diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 9fad939..5f02125 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -65,36 +65,34 @@ impl PacketAsyncRead for PubRec where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; - if remaining_length == 2 { - return Ok(( - Self { - packet_identifier, - reason_code: PubRecReasonCode::Success, - properties: PubRecProperties::default(), - }, - total_read_bytes, - )); - } - - let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; - let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; - - total_read_bytes += reason_code_read_bytes + properties_read_bytes; - - Ok(( + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; + if remaining_length == 2 { + return Ok(( Self { packet_identifier, - properties, - reason_code, + reason_code: PubRecReasonCode::Success, + properties: PubRecProperties::default(), }, total_read_bytes, - )) + )); } + + let (reason_code, reason_code_read_bytes) = PubRecReasonCode::async_read(stream).await?; + let (properties, properties_read_bytes) = PubRecProperties::async_read(stream).await?; + + total_read_bytes += reason_code_read_bytes + properties_read_bytes; + + Ok(( + Self { + packet_identifier, + properties, + reason_code, + }, + total_read_bytes, + )) } } diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 4bbd138..4f05abe 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -50,32 +50,30 @@ impl PacketAsyncRead for SubAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; - total_read_bytes += 2 + proproperties_read_bytes; - let mut reason_codes = vec![]; - loop { - let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); - - if remaining_length == total_read_bytes { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubAckProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); - Ok(( - Self { - packet_identifier, - properties, - reason_codes, - }, - total_read_bytes, - )) + if remaining_length == total_read_bytes { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) } } @@ -96,22 +94,20 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for SubAck where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; - - total_written_bytes += self.properties.async_write(stream).await?; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - for reason_code in &self.reason_codes { - reason_code.async_write(stream).await?; - } - total_written_bytes += self.reason_codes.len(); + total_written_bytes += self.properties.async_write(stream).await?; - Ok(total_written_bytes) + for reason_code in &self.reason_codes { + reason_code.async_write(stream).await?; } + total_written_bytes += self.reason_codes.len(); + + Ok(total_written_bytes) } } diff --git a/mqrstt/src/packets/subscribe/mod.rs b/mqrstt/src/packets/subscribe/mod.rs index 28f4899..deb39ab 100644 --- a/mqrstt/src/packets/subscribe/mod.rs +++ b/mqrstt/src/packets/subscribe/mod.rs @@ -61,34 +61,32 @@ impl PacketAsyncRead for Subscribe where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; - total_read_bytes += 2 + proproperties_read_bytes; - - let mut topics = vec![]; - loop { - let (topic, topic_read_bytes) = Box::::async_read(stream).await?; - let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; - total_read_bytes += topic_read_bytes + options_read_bytes; - topics.push((topic, options)); - - if remaining_length >= total_read_bytes { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, proproperties_read_bytes) = SubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + proproperties_read_bytes; - Ok(( - Self { - packet_identifier, - properties, - topics, - }, - total_read_bytes, - )) + let mut topics = vec![]; + loop { + let (topic, topic_read_bytes) = Box::::async_read(stream).await?; + let (options, options_read_bytes) = SubscriptionOptions::async_read(stream).await?; + total_read_bytes += topic_read_bytes + options_read_bytes; + topics.push((topic, options)); + + if remaining_length >= total_read_bytes { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } @@ -110,20 +108,18 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Subscribe where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - total_written_bytes += self.properties.async_write(stream).await?; - for (topic, options) in &self.topics { - total_written_bytes += topic.async_write(stream).await?; - total_written_bytes += options.async_write(stream).await?; - } - Ok(total_written_bytes) + total_written_bytes += self.properties.async_write(stream).await?; + for (topic, options) in &self.topics { + total_written_bytes += topic.async_write(stream).await?; + total_written_bytes += options.async_write(stream).await?; } + Ok(total_written_bytes) } } @@ -200,24 +196,22 @@ impl MqttAsyncRead for SubscriptionOptions where S: tokio::io::AsyncRead + Unpin, { - fn async_read(stream: &mut S) -> impl std::future::Future> { - async move { - let byte = stream.read_u8().await?; + async fn async_read(stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let byte = stream.read_u8().await?; - let retain_handling_part = (byte & 0b00110000) >> 4; - let retain_as_publish_part = (byte & 0b00001000) >> 3; - let no_local_part = (byte & 0b00000100) >> 2; - let qos_part = byte & 0b00000011; + let retain_handling_part = (byte & 0b00110000) >> 4; + let retain_as_publish_part = (byte & 0b00001000) >> 3; + let no_local_part = (byte & 0b00000100) >> 2; + let qos_part = byte & 0b00000011; - let options = Self { - retain_handling: RetainHandling::from_u8(retain_handling_part)?, - retain_as_publish: retain_as_publish_part != 0, - no_local: no_local_part != 0, - qos: QoS::from_u8(qos_part)?, - }; + let options = Self { + retain_handling: RetainHandling::from_u8(retain_handling_part)?, + retain_as_publish: retain_as_publish_part != 0, + no_local: no_local_part != 0, + qos: QoS::from_u8(qos_part)?, + }; - Ok((options, 1)) - } + Ok((options, 1)) } } @@ -386,7 +380,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Slices -------------------- @@ -395,7 +389,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } // -------------------- Vecs -------------------- @@ -404,7 +398,7 @@ where T: IntoSingleSubscription, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.into_iter().map(IntoSingleSubscription::into).collect()) } } impl From<&Vec> for SubscribeTopics @@ -412,7 +406,7 @@ where for<'any> &'any T: IntoSingleSubscription, { fn from(value: &Vec) -> Self { - Self(value.iter().map(|val| IntoSingleSubscription::into(val)).collect()) + Self(value.iter().map(IntoSingleSubscription::into).collect()) } } diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index da06d52..3153f3b 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -49,36 +49,34 @@ impl PacketAsyncRead for UnsubAck where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + total_read_bytes += 2; - let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; - total_read_bytes += properties_read_bytes; + let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; + total_read_bytes += properties_read_bytes; - let mut reason_codes = vec![]; - loop { - let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; + let mut reason_codes = vec![]; + loop { + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); + reason_codes.push(reason_code); - if total_read_bytes >= remaining_length { - break; - } + if total_read_bytes >= remaining_length { + break; } - - Ok(( - Self { - packet_identifier, - properties, - reason_codes, - }, - total_read_bytes, - )) } + + Ok(( + Self { + packet_identifier, + properties, + reason_codes, + }, + total_read_bytes, + )) } } diff --git a/mqrstt/src/packets/unsubscribe/mod.rs b/mqrstt/src/packets/unsubscribe/mod.rs index 6e9af89..9051f24 100644 --- a/mqrstt/src/packets/unsubscribe/mod.rs +++ b/mqrstt/src/packets/unsubscribe/mod.rs @@ -58,34 +58,32 @@ impl PacketAsyncRead for Unsubscribe where S: tokio::io::AsyncRead + Unpin, { - fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> impl std::future::Future> { - async move { - let mut total_read_bytes = 0; - let packet_identifier = stream.read_u16().await?; - let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; - total_read_bytes += 2 + properties_read_bytes; - - let mut topics = vec![]; - loop { - let (topic, topic_read_size) = Box::::async_read(stream).await?; - total_read_bytes += topic_read_size; - - topics.push(topic); - - if total_read_bytes >= remaining_length { - break; - } - } + async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { + let mut total_read_bytes = 0; + let packet_identifier = stream.read_u16().await?; + let (properties, properties_read_bytes) = UnsubscribeProperties::async_read(stream).await?; + total_read_bytes += 2 + properties_read_bytes; + + let mut topics = vec![]; + loop { + let (topic, topic_read_size) = Box::::async_read(stream).await?; + total_read_bytes += topic_read_size; + + topics.push(topic); - Ok(( - Self { - packet_identifier, - properties, - topics, - }, - total_read_bytes, - )) + if total_read_bytes >= remaining_length { + break; + } } + + Ok(( + Self { + packet_identifier, + properties, + topics, + }, + total_read_bytes, + )) } } @@ -105,19 +103,18 @@ impl crate::packets::mqtt_trait::PacketAsyncWrite for Unsubscribe where S: tokio::io::AsyncWrite + Unpin, { - fn async_write(&self, stream: &mut S) -> impl std::future::Future> { + async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; - async move { - let mut total_written_bytes = 2; - stream.write_u16(self.packet_identifier).await?; - for topic in &self.topics { - total_written_bytes += topic.async_write(stream).await?; - } + let mut total_written_bytes = 2; + stream.write_u16(self.packet_identifier).await?; - Ok(total_written_bytes) + for topic in &self.topics { + total_written_bytes += topic.async_write(stream).await?; } + + Ok(total_written_bytes) } } @@ -206,7 +203,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T; S]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } // -------------------- Slices -------------------- @@ -215,7 +212,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &[T]) -> Self { - Self(value.iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } impl From<&[&str]> for UnsubscribeTopics { @@ -239,7 +236,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(&val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } @@ -248,7 +245,7 @@ where for<'any> &'any T: IntoUnsubscribeTopic, { fn from(value: &Vec) -> Self { - Self(value.into_iter().map(|val| IntoUnsubscribeTopic::into(val)).collect()) + Self(value.iter().map(IntoUnsubscribeTopic::into).collect()) } } diff --git a/mqrstt/src/state.rs b/mqrstt/src/state.rs index 3c52951..4ece423 100644 --- a/mqrstt/src/state.rs +++ b/mqrstt/src/state.rs @@ -34,7 +34,7 @@ pub struct State { impl State { pub fn new(receive_maximum: u16, apkid: AvailablePacketIds) -> Self { - let state = Self { + Self { apkid, outgoing_sub: Mutex::new(BTreeSet::new()), @@ -45,9 +45,7 @@ impl State { outgoing_pub_order: Mutex::new(VecDeque::new()), outgoing_rel: Mutex::new(BTreeSet::new()), incoming_pub: Mutex::new(BTreeSet::new()), - }; - - state + } } pub fn make_pkid_available(&self, pkid: u16) -> Result<(), HandlerError> { diff --git a/mqrstt/src/tokio/network.rs b/mqrstt/src/tokio/network.rs index 3caa0e4..bdfb57b 100644 --- a/mqrstt/src/tokio/network.rs +++ b/mqrstt/src/tokio/network.rs @@ -96,13 +96,10 @@ where return Err(ConnectionError::NoNetwork); } - match self.tokio_select(handler).await { - otherwise => { - self.network = None; + let result = self.tokio_select(handler).await; + self.network = None; - otherwise - } - } + result } async fn tokio_select(&mut self, handler: &mut H) -> Result { diff --git a/mqrstt/src/tokio/stream.rs b/mqrstt/src/tokio/stream.rs index a5647e7..9a720ad 100644 --- a/mqrstt/src/tokio/stream.rs +++ b/mqrstt/src/tokio/stream.rs @@ -19,51 +19,47 @@ impl StreamExt for S where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Sized + Unpin, { - fn connect(&mut self, options: &ConnectOptions) -> impl std::future::Future> { - async move { - let connect = options.create_connect_from_options(); + async fn connect(&mut self, options: &ConnectOptions) -> Result { + let connect = options.create_connect_from_options(); - self.write_packet(&connect).await?; + self.write_packet(&connect).await?; - let packet = Packet::async_read(self).await?; - if let Packet::ConnAck(con) = packet { - if con.reason_code == ConnAckReasonCode::Success { - #[cfg(feature = "logs")] - trace!("Connected to server"); - Ok(con) - } else { - Err(ConnectionError::ConnectionRefused(con.reason_code)) - } + let packet = Packet::async_read(self).await?; + if let Packet::ConnAck(con) = packet { + if con.reason_code == ConnAckReasonCode::Success { + #[cfg(feature = "logs")] + trace!("Connected to server"); + Ok(con) } else { - Err(ConnectionError::NotConnAck(packet)) + Err(ConnectionError::ConnectionRefused(con.reason_code)) } + } else { + Err(ConnectionError::NotConnAck(packet)) } } - fn read_packet(&mut self) -> impl std::future::Future> { - async move { Ok(Packet::async_read(self).await?) } + async fn read_packet(&mut self) -> Result { + Ok(Packet::async_read(self).await?) } - fn write_packet(&mut self, packet: &Packet) -> impl std::future::Future> { - async move { - match packet.async_write(self).await { - Ok(_) => (), - Err(err) => { - return match err { - crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), - crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), - } + async fn write_packet(&mut self, packet: &Packet) -> Result<(), ConnectionError> { + match packet.async_write(self).await { + Ok(_) => (), + Err(err) => { + return match err { + crate::packets::error::WriteError::SerializeError(serialize_error) => Err(ConnectionError::SerializationError(serialize_error)), + crate::packets::error::WriteError::IoError(error) => Err(ConnectionError::Io(error)), } } + } - #[cfg(feature = "logs")] - trace!("Sending packet {}", packet); + #[cfg(feature = "logs")] + trace!("Sending packet {}", packet); - self.flush().await?; - // self.flush_packets().await?; + self.flush().await?; + // self.flush_packets().await?; - Ok(()) - } + Ok(()) } async fn write_packets(&mut self, packets: &[Packet]) -> Result<(), ConnectionError> {