diff --git a/mqrstt/src/lib.rs b/mqrstt/src/lib.rs index e4d0477..e5ac218 100644 --- a/mqrstt/src/lib.rs +++ b/mqrstt/src/lib.rs @@ -409,7 +409,7 @@ mod tokio_lib_test { network.connect(stream, &mut pingresp).await.unwrap(); let network_handle = tokio::task::spawn(async move { - let result = network.run(&mut pingresp).await; + let _result = network.run(&mut pingresp).await; // check result and or restart the connection pingresp }); diff --git a/mqrstt/src/packets/auth/mod.rs b/mqrstt/src/packets/auth/mod.rs index e16fdda..6f55d4d 100644 --- a/mqrstt/src/packets/auth/mod.rs +++ b/mqrstt/src/packets/auth/mod.rs @@ -48,9 +48,9 @@ where S: tokio::io::AsyncWrite + Unpin, { async fn async_write(&self, stream: &mut S) -> Result { - let reason_code_writen = self.reason_code.async_write(stream).await?; - let properties_writen = self.properties.async_write(stream).await?; - Ok(reason_code_writen + properties_writen) + let reason_code_written = self.reason_code.async_write(stream).await?; + let properties_written = self.properties.async_write(stream).await?; + Ok(reason_code_written + properties_written) } } diff --git a/mqrstt/src/packets/connack/mod.rs b/mqrstt/src/packets/connack/mod.rs index 94da71f..90dc715 100644 --- a/mqrstt/src/packets/connack/mod.rs +++ b/mqrstt/src/packets/connack/mod.rs @@ -85,11 +85,11 @@ where { async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - let connack_flags_writen = self.connack_flags.async_write(stream).await?; - let reason_code_writen = self.reason_code.async_write(stream).await?; - let connack_properties_writen = self.connack_properties.async_write(stream).await?; + let connack_flags_written = self.connack_flags.async_write(stream).await?; + let reason_code_written = self.reason_code.async_write(stream).await?; + let connack_properties_written = self.connack_properties.async_write(stream).await?; - Ok(connack_flags_writen + reason_code_writen + connack_properties_writen) + Ok(connack_flags_written + reason_code_written + connack_properties_written) } } diff --git a/mqrstt/src/packets/connect/mod.rs b/mqrstt/src/packets/connect/mod.rs index 463f6e4..cb5a611 100644 --- a/mqrstt/src/packets/connect/mod.rs +++ b/mqrstt/src/packets/connect/mod.rs @@ -223,12 +223,12 @@ where use crate::packets::mqtt_trait::MqttAsyncWrite; use tokio::io::AsyncWriteExt; async move { - let mut total_writen_bytes = 6 // protocol header + let mut total_written_bytes = 6 // protocol header + 1 // protocol version + 1 // connect flags + 2; // keep alive let protocol = [0x00, 0x04, b'M', b'Q', b'T', b'T']; - // We allready start with 6 as total writen bytes thus dont add anymore + // We allready start with 6 as total written bytes thus dont add anymore stream.write_all(&protocol).await?; self.protocol_version.async_write(stream).await?; @@ -250,21 +250,21 @@ where stream.write_u16(self.keep_alive).await?; - total_writen_bytes += self.connect_properties.async_write(stream).await?; + total_written_bytes += self.connect_properties.async_write(stream).await?; - total_writen_bytes += self.client_id.async_write(stream).await?; + total_written_bytes += self.client_id.async_write(stream).await?; if let Some(last_will) = &self.last_will { - total_writen_bytes += last_will.async_write(stream).await?; + total_written_bytes += last_will.async_write(stream).await?; } if let Some(username) = &self.username { - total_writen_bytes += username.async_write(stream).await?; + total_written_bytes += username.async_write(stream).await?; } if let Some(password) = &self.password { - total_writen_bytes += password.async_write(stream).await?; + total_written_bytes += password.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/macros/properties_macros.rs b/mqrstt/src/packets/macros/properties_macros.rs index b71e149..33d80ba 100644 --- a/mqrstt/src/packets/macros/properties_macros.rs +++ b/mqrstt/src/packets/macros/properties_macros.rs @@ -35,13 +35,13 @@ macro_rules! define_properties { impl $crate::packets::mqtt_trait::MqttAsyncWrite for $name where S: tokio::io::AsyncWrite + Unpin { async fn async_write(&self, stream: &mut S) -> Result { - let mut bytes_writen = 0; + let mut bytes_written = 0; $crate::packets::VariableInteger::write_async_variable_integer(&self.wire_len(), stream).await?; $( - $crate::packets::macros::properties_write!(self, bytes_writen, stream, PropertyType::$prop_variant); + $crate::packets::macros::properties_write!(self, bytes_written, stream, PropertyType::$prop_variant); )* - Ok(bytes_writen) + Ok(bytes_written) } } @@ -556,173 +556,173 @@ macro_rules! properties_read_match_branch_name { } macro_rules! properties_write { - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::PayloadFormatIndicator) => { if let Some(payload_format_indicator) = &($self.payload_format_indicator) { - $bytes_writen += PropertyType::PayloadFormatIndicator.async_write($stream).await?; - $bytes_writen += payload_format_indicator.async_write($stream).await?; + $bytes_written += PropertyType::PayloadFormatIndicator.async_write($stream).await?; + $bytes_written += payload_format_indicator.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MessageExpiryInterval) => { if let Some(message_expiry_interval) = &($self.message_expiry_interval) { - $bytes_writen += PropertyType::MessageExpiryInterval.async_write($stream).await?; - $bytes_writen += message_expiry_interval.async_write($stream).await?; + $bytes_written += PropertyType::MessageExpiryInterval.async_write($stream).await?; + $bytes_written += message_expiry_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ContentType) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ContentType) => { if let Some(content_type) = &($self.content_type) { - $bytes_writen += PropertyType::ContentType.async_write($stream).await?; - $bytes_writen += content_type.async_write($stream).await?; + $bytes_written += PropertyType::ContentType.async_write($stream).await?; + $bytes_written += content_type.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseTopic) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseTopic) => { if let Some(response_topic) = &($self.response_topic) { - $bytes_writen += PropertyType::ResponseTopic.async_write($stream).await?; - $bytes_writen += response_topic.as_ref().async_write($stream).await?; + $bytes_written += PropertyType::ResponseTopic.async_write($stream).await?; + $bytes_written += response_topic.as_ref().async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::CorrelationData) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::CorrelationData) => { if let Some(correlation_data) = &($self.correlation_data) { - $bytes_writen += PropertyType::CorrelationData.async_write($stream).await?; - $bytes_writen += correlation_data.async_write($stream).await?; + $bytes_written += PropertyType::CorrelationData.async_write($stream).await?; + $bytes_written += correlation_data.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifier) => { if let Some(sub_id) = &($self.subscription_identifier) { - $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; - $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ListSubscriptionIdentifier) => { for sub_id in &($self.subscription_identifiers) { - $bytes_writen += PropertyType::SubscriptionIdentifier.async_write($stream).await?; - $bytes_writen += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifier.async_write($stream).await?; + $bytes_written += $crate::packets::primitive::VariableInteger::write_async_variable_integer(sub_id, $stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SessionExpiryInterval) => { if let Some(session_expiry_interval) = &($self.session_expiry_interval) { - $bytes_writen += PropertyType::SessionExpiryInterval.async_write($stream).await?; - $bytes_writen += session_expiry_interval.async_write($stream).await?; + $bytes_written += PropertyType::SessionExpiryInterval.async_write($stream).await?; + $bytes_written += session_expiry_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerKeepAlive) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AssignedClientIdentifier) => {}; + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerKeepAlive) => { if let Some(server_keep_alive) = &($self.server_keep_alive) { - $bytes_writen += PropertyType::ServerKeepAlive.async_write($stream).await?; - $bytes_writen += server_keep_alive.async_write($stream).await?; + $bytes_written += PropertyType::ServerKeepAlive.async_write($stream).await?; + $bytes_written += server_keep_alive.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationMethod) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationMethod) => { if let Some(authentication_method) = &($self.authentication_method) { - $bytes_writen += PropertyType::AuthenticationMethod.async_write($stream).await?; - $bytes_writen += authentication_method.async_write($stream).await?; + $bytes_written += PropertyType::AuthenticationMethod.async_write($stream).await?; + $bytes_written += authentication_method.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::AuthenticationData) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::AuthenticationData) => { if let Some(authentication_data) = &($self.authentication_data) { if !authentication_data.is_empty() && ($self.authentication_method).is_some() { - $bytes_writen += PropertyType::AuthenticationData.async_write($stream).await?; - $bytes_writen += authentication_data.async_write($stream).await?; + $bytes_written += PropertyType::AuthenticationData.async_write($stream).await?; + $bytes_written += authentication_data.async_write($stream).await?; } } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestProblemInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestProblemInformation) => { if let Some(request_problem_information) = &($self.request_problem_information) { - $bytes_writen += PropertyType::RequestProblemInformation.async_write($stream).await?; - $bytes_writen += request_problem_information.async_write($stream).await?; + $bytes_written += PropertyType::RequestProblemInformation.async_write($stream).await?; + $bytes_written += request_problem_information.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WillDelayInterval) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WillDelayInterval) => { if let Some(delay_interval) = &($self.will_delay_interval) { - $bytes_writen += PropertyType::WillDelayInterval.async_write($stream).await?; - $bytes_writen += delay_interval.async_write($stream).await?; + $bytes_written += PropertyType::WillDelayInterval.async_write($stream).await?; + $bytes_written += delay_interval.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RequestResponseInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RequestResponseInformation) => { if let Some(request_response_information) = &($self.request_response_information) { - $bytes_writen += PropertyType::RequestResponseInformation.async_write($stream).await?; - $bytes_writen += request_response_information.async_write($stream).await?; + $bytes_written += PropertyType::RequestResponseInformation.async_write($stream).await?; + $bytes_written += request_response_information.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ResponseInformation) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ResponseInformation) => { if let Some(response_info) = &($self.response_info) { - $bytes_writen += PropertyType::ResponseInformation.async_write($stream).await?; - $bytes_writen += response_info.async_write($stream).await?; + $bytes_written += PropertyType::ResponseInformation.async_write($stream).await?; + $bytes_written += response_info.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ServerReference) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ServerReference) => { if let Some(server_refrence) = &($self.server_reference) { - $bytes_writen += PropertyType::ServerReference.async_write($stream).await?; + $bytes_written += PropertyType::ServerReference.async_write($stream).await?; server_refrence.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReasonString) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReasonString) => { if let Some(reason_string) = &($self.reason_string) { - $bytes_writen += PropertyType::ReasonString.async_write($stream).await?; - $bytes_writen += reason_string.async_write($stream).await?; + $bytes_written += PropertyType::ReasonString.async_write($stream).await?; + $bytes_written += reason_string.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::ReceiveMaximum) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::ReceiveMaximum) => { if let Some(receive_maximum) = &($self.receive_maximum) { - $bytes_writen += PropertyType::ReceiveMaximum.async_write($stream).await?; - $bytes_writen += receive_maximum.async_write($stream).await?; + $bytes_written += PropertyType::ReceiveMaximum.async_write($stream).await?; + $bytes_written += receive_maximum.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAliasMaximum) => { if let Some(topic_alias_maximum) = &($self.topic_alias_maximum) { - $bytes_writen += PropertyType::TopicAliasMaximum.async_write($stream).await?; - $bytes_writen += topic_alias_maximum.async_write($stream).await?; + $bytes_written += PropertyType::TopicAliasMaximum.async_write($stream).await?; + $bytes_written += topic_alias_maximum.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::TopicAlias) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::TopicAlias) => { if let Some(topic_alias) = &($self.topic_alias) { - $bytes_writen += PropertyType::TopicAlias.async_write($stream).await?; - $bytes_writen += topic_alias.async_write($stream).await?; + $bytes_written += PropertyType::TopicAlias.async_write($stream).await?; + $bytes_written += topic_alias.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumQos) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumQos) => { if let Some(maximum_qos) = &($self.maximum_qos) { - $bytes_writen += PropertyType::MaximumQos.async_write($stream).await?; - $bytes_writen += maximum_qos.async_write($stream).await?; + $bytes_written += PropertyType::MaximumQos.async_write($stream).await?; + $bytes_written += maximum_qos.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::RetainAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::RetainAvailable) => { if let Some(retain_available) = &($self.retain_available) { - $bytes_writen += PropertyType::RetainAvailable.async_write($stream).await?; - $bytes_writen += retain_available.async_write($stream).await?; + $bytes_written += PropertyType::RetainAvailable.async_write($stream).await?; + $bytes_written += retain_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::UserProperty) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::UserProperty) => { for (key, value) in &($self.user_properties) { - $bytes_writen += PropertyType::UserProperty.async_write($stream).await?; - $bytes_writen += key.async_write($stream).await?; - $bytes_writen += value.async_write($stream).await?; + $bytes_written += PropertyType::UserProperty.async_write($stream).await?; + $bytes_written += key.async_write($stream).await?; + $bytes_written += value.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::MaximumPacketSize) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::MaximumPacketSize) => { if let Some(maximum_packet_size) = &($self.maximum_packet_size) { - $bytes_writen += PropertyType::MaximumPacketSize.async_write($stream).await?; - $bytes_writen += maximum_packet_size.async_write($stream).await?; + $bytes_written += PropertyType::MaximumPacketSize.async_write($stream).await?; + $bytes_written += maximum_packet_size.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::WildcardSubscriptionAvailable) => { if let Some(wildcards_available) = &($self.wildcards_available) { - $bytes_writen += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; - $bytes_writen += wildcards_available.async_write($stream).await?; + $bytes_written += PropertyType::WildcardSubscriptionAvailable.async_write($stream).await?; + $bytes_written += wildcards_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SubscriptionIdentifierAvailable) => { if let Some(subscription_ids_available) = &($self.subscription_ids_available) { - $bytes_writen += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; - $bytes_writen += subscription_ids_available.async_write($stream).await?; + $bytes_written += PropertyType::SubscriptionIdentifierAvailable.async_write($stream).await?; + $bytes_written += subscription_ids_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { + ($self:ident, $bytes_written:ident, $stream:ident, PropertyType::SharedSubscriptionAvailable) => { if let Some(shared_subscription_available) = &($self.shared_subscription_available) { - $bytes_writen += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; - $bytes_writen += shared_subscription_available.async_write($stream).await?; + $bytes_written += PropertyType::SharedSubscriptionAvailable.async_write($stream).await?; + $bytes_written += shared_subscription_available.async_write($stream).await?; } }; - ($self:ident, $bytes_writen:ident, $stream:ident, $unknown:ident) => { + ($self:ident, $bytes_written:ident, $stream:ident, $unknown:ident) => { compile_error!(concat!("Unknown property: ", stringify!($unknown))); }; } diff --git a/mqrstt/src/packets/mod.rs b/mqrstt/src/packets/mod.rs index 2eb135a..39fbc88 100644 --- a/mqrstt/src/packets/mod.rs +++ b/mqrstt/src/packets/mod.rs @@ -103,9 +103,7 @@ impl Packet { Packet::PubRel(_) => 0b0110_0010, Packet::PubComp(_) => 0b0111_0000, Packet::Subscribe(_) => 0b1000_0010, - Packet::SubAck(_) => { - unreachable!() - } + Packet::SubAck(_) => 0b1001_0000, Packet::Unsubscribe(_) => 0b1010_0010, Packet::UnsubAck(_) => 0b1011_0000, Packet::PingReq => 0b1100_0000, @@ -168,17 +166,20 @@ impl Packet { p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::SubAck(_) => { - unreachable!() + Packet::SubAck(p) => { + buf.put_u8(0b1001_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::Unsubscribe(p) => { buf.put_u8(0b1010_0010); p.wire_len().write_variable_integer(buf)?; p.write(buf)?; } - Packet::UnsubAck(_) => { - unreachable!(); - // buf.put_u8(0b1011_0000); + Packet::UnsubAck(p) => { + buf.put_u8(0b1011_0000); + p.wire_len().write_variable_integer(buf)?; + p.write(buf)?; } Packet::PingReq => { buf.put_u8(0b1100_0000); @@ -259,17 +260,20 @@ impl Packet { written += p.wire_len().write_async_variable_integer(stream).await?; written += p.async_write(stream).await?; } - Packet::SubAck(_) => { - unreachable!() + Packet::SubAck(p) => { + stream.write_u8(0b1001_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; } Packet::Unsubscribe(p) => { stream.write_u8(0b1010_0010).await?; written += p.wire_len().write_async_variable_integer(stream).await?; written += p.async_write(stream).await?; } - Packet::UnsubAck(_) => { - unreachable!(); - // stream.write_u8(0b1011_0000).await?; + Packet::UnsubAck(p) => { + stream.write_u8(0b1011_0000).await?; + written += p.wire_len().write_async_variable_integer(stream).await?; + written += p.async_write(stream).await?; } Packet::PingReq => { stream.write_u8(0b1100_0000).await?; @@ -396,6 +400,28 @@ impl Display for Packet { } } +impl WireLength for Packet { + fn wire_len(&self) -> usize { + match self { + Packet::Connect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::ConnAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Publish(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRec(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubRel(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PubComp(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Subscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::SubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Unsubscribe(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::UnsubAck(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::PingReq => 2, + Packet::PingResp => 2, + Packet::Disconnect(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + Packet::Auth(p) => 1 + p.wire_len().variable_integer_len() + p.wire_len(), + } + } +} + /// 2.1.2 MQTT Control Packet type #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)] pub enum PacketType { @@ -447,6 +473,7 @@ impl std::fmt::Display for PacketType { #[cfg(test)] mod tests { + use bytes::BytesMut; use crate::packets::Packet; @@ -454,26 +481,39 @@ mod tests { use crate::tests::test_packets::*; #[rstest::rstest] - #[case(ping_req_case().1)] - #[case(ping_resp_case().1)] - #[case(connack_case().1)] - #[case(create_subscribe_packet(1))] - #[case(create_subscribe_packet(65335))] - #[case(create_puback_packet(1))] - #[case(create_puback_packet(65335))] - #[case(create_disconnect_packet())] - #[case(create_connack_packet(true))] - #[case(create_connack_packet(false))] - #[case(publish_packet_1())] - #[case(publish_packet_2())] - #[case(publish_packet_3())] - #[case(publish_packet_4())] - #[case(create_empty_publish_packet())] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] fn test_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + let mut buffer = BytesMut::new(); + packet.write(&mut buffer).unwrap(); + + let wire_len = packet.wire_len(); + assert_eq!(wire_len, buffer.len()); + let res1 = Packet::read(&mut buffer).unwrap(); + assert_eq!(packet, res1); + let mut buffer = BytesMut::new(); res1.write(&mut buffer).unwrap(); let res2 = Packet::read(&mut buffer).unwrap(); @@ -533,6 +573,43 @@ mod tests { assert_eq!(out, input) } + #[rstest::rstest] + #[case::ping_req_case(ping_req_case().1)] + #[case::ping_resp_case(ping_resp_case().1)] + #[case::connack_case(connack_case().1)] + #[case::create_subscribe_packet(create_subscribe_packet(1))] + #[case::create_subscribe_packet(create_subscribe_packet(65335))] + #[case::create_puback_packet(create_puback_packet(1))] + #[case::create_puback_packet(create_puback_packet(65335))] + #[case::create_disconnect_packet(create_disconnect_packet())] + #[case::create_connack_packet(create_connack_packet(true))] + #[case::create_connack_packet(create_connack_packet(false))] + #[case::publish_packet_1(publish_packet_1())] + #[case::publish_packet_2(publish_packet_2())] + #[case::publish_packet_3(publish_packet_3())] + #[case::publish_packet_4(publish_packet_4())] + #[case::create_empty_publish_packet(create_empty_publish_packet())] + #[case::subscribe(subscribe_case())] + #[case::suback(suback_case())] + #[case::unsubscribe(unsubscribe_case())] + #[case::unsuback(unsuback_case())] + #[tokio::test] + async fn test_async_write_read_write_read_cases(#[case] packet: Packet) { + use crate::packets::WireLength; + + let mut buffer = Vec::with_capacity(1000); + packet.async_write(&mut buffer).await.unwrap(); + + let wire_len = packet.wire_len(); + assert_eq!(wire_len, buffer.len()); + + let mut buf = buffer.as_slice(); + + let res1 = Packet::async_read(&mut buf).await.unwrap(); + + assert_eq!(packet, res1); + } + // #[rstest::rstest] // #[case(&[59, 1, 0, 59])] // #[case(&[16, 14, 0, 4, 77, 81, 84, 84, 5, 247, 247, 252, 1, 17, 247, 247, 247])] diff --git a/mqrstt/src/packets/pubcomp/mod.rs b/mqrstt/src/packets/pubcomp/mod.rs index 3dd5ed7..459b036 100644 --- a/mqrstt/src/packets/pubcomp/mod.rs +++ b/mqrstt/src/packets/pubcomp/mod.rs @@ -126,18 +126,18 @@ where { async fn async_write(&self, stream: &mut S) -> Result { use crate::packets::mqtt_trait::MqttAsyncWrite; - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubCompReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } diff --git a/mqrstt/src/packets/pubrec/mod.rs b/mqrstt/src/packets/pubrec/mod.rs index 5f02125..92ca8be 100644 --- a/mqrstt/src/packets/pubrec/mod.rs +++ b/mqrstt/src/packets/pubrec/mod.rs @@ -118,18 +118,18 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubRecReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/pubrel/mod.rs b/mqrstt/src/packets/pubrel/mod.rs index c97d285..03d78d9 100644 --- a/mqrstt/src/packets/pubrel/mod.rs +++ b/mqrstt/src/packets/pubrel/mod.rs @@ -116,18 +116,18 @@ where fn async_write(&self, stream: &mut S) -> impl std::future::Future> { use crate::packets::mqtt_trait::MqttAsyncWrite; async move { - let mut total_writen_bytes = 2; + let mut total_written_bytes = 2; self.packet_identifier.async_write(stream).await?; if self.reason_code == PubRelReasonCode::Success && self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - return Ok(total_writen_bytes); + return Ok(total_written_bytes); } else if self.properties.reason_string.is_none() && self.properties.user_properties.is_empty() { - total_writen_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; } else { - total_writen_bytes += self.reason_code.async_write(stream).await?; - total_writen_bytes += self.properties.async_write(stream).await?; + total_written_bytes += self.reason_code.async_write(stream).await?; + total_written_bytes += self.properties.async_write(stream).await?; } - Ok(total_writen_bytes) + Ok(total_written_bytes) } } } diff --git a/mqrstt/src/packets/suback/mod.rs b/mqrstt/src/packets/suback/mod.rs index 4f05abe..df8b1a2 100644 --- a/mqrstt/src/packets/suback/mod.rs +++ b/mqrstt/src/packets/suback/mod.rs @@ -1,4 +1,5 @@ mod properties; + pub use properties::SubAckProperties; mod reason_code; @@ -7,6 +8,7 @@ pub use reason_code::SubAckReasonCode; use super::{ error::SerializeError, mqtt_trait::{MqttAsyncRead, MqttRead, MqttWrite, PacketAsyncRead, PacketRead, PacketWrite}, + VariableInteger, WireLength, }; use bytes::BufMut; use tokio::io::AsyncReadExt; @@ -23,19 +25,21 @@ pub struct SubAck { } impl PacketRead for SubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = SubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; - loop { - let reason_code = SubAckReasonCode::read(&mut buf)?; - - reason_codes.push(reason_code); - if buf.is_empty() { + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read >= remaining_length { break; } + + let reason_code = SubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; } Ok(Self { @@ -57,13 +61,13 @@ where total_read_bytes += 2 + proproperties_read_bytes; let mut reason_codes = vec![]; loop { - let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; - total_read_bytes += reason_code_read_bytes; - reason_codes.push(reason_code); - if remaining_length == total_read_bytes { break; } + + let (reason_code, reason_code_read_bytes) = SubAckReasonCode::async_read(stream).await?; + total_read_bytes += reason_code_read_bytes; + reason_codes.push(reason_code); } Ok(( @@ -111,6 +115,12 @@ where } } +impl WireLength for SubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + #[cfg(test)] mod test { use bytes::BytesMut; @@ -128,7 +138,7 @@ mod test { ]; let data = BytesMut::from(&buf[..]); - let sub_ack = SubAck::read(0, 0, data.clone().into()).unwrap(); + let sub_ack = SubAck::read(0, 5, data.clone().into()).unwrap(); let mut result = BytesMut::new(); sub_ack.write(&mut result).unwrap(); diff --git a/mqrstt/src/packets/suback/properties.rs b/mqrstt/src/packets/suback/properties.rs index be7ac8b..15f6997 100644 --- a/mqrstt/src/packets/suback/properties.rs +++ b/mqrstt/src/packets/suback/properties.rs @@ -44,7 +44,7 @@ impl MqttRead for SubAckProperties { e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::SubAck)), } - if buf.is_empty() { + if properties_data.is_empty() { break; } } @@ -58,7 +58,6 @@ impl MqttWrite for SubAckProperties { if let Some(sub_id) = self.subscription_identifier { PropertyType::SubscriptionIdentifier.write(buf)?; sub_id.write_variable_integer(buf)?; - // sub_id.write_variable_integer(buf)?; } for (key, value) in &self.user_properties { PropertyType::UserProperty.write(buf)?; diff --git a/mqrstt/src/packets/unsuback/mod.rs b/mqrstt/src/packets/unsuback/mod.rs index 3153f3b..9823bed 100644 --- a/mqrstt/src/packets/unsuback/mod.rs +++ b/mqrstt/src/packets/unsuback/mod.rs @@ -12,7 +12,7 @@ use tokio::io::AsyncReadExt; use super::error::SerializeError; use super::mqtt_trait::{MqttRead, MqttWrite, PacketRead, PacketWrite}; -use super::PacketAsyncRead; +use super::{PacketAsyncRead, VariableInteger, WireLength}; /// UnsubAck packet is sent by the server in response to an [`crate::packets::Unsubscribe`] packet. #[derive(Debug, Default, PartialEq, Eq, Clone)] @@ -23,18 +23,20 @@ pub struct UnsubAck { } impl PacketRead for UnsubAck { - fn read(_: u8, _: usize, mut buf: bytes::Bytes) -> Result { + fn read(_: u8, remaining_length: usize, mut buf: bytes::Bytes) -> Result { let packet_identifier = u16::read(&mut buf)?; let properties = UnsubAckProperties::read(&mut buf)?; let mut reason_codes = vec![]; - loop { - let reason_code = UnsubAckReasonCode::read(&mut buf)?; - reason_codes.push(reason_code); - - if buf.is_empty() { + let mut read = 2 + properties.wire_len().variable_integer_len() + properties.wire_len(); + loop { + if read == remaining_length { break; } + + let reason_code = UnsubAckReasonCode::read(&mut buf)?; + reason_codes.push(reason_code); + read += 1; } Ok(Self { @@ -50,23 +52,22 @@ where S: tokio::io::AsyncRead + Unpin, { async fn async_read(_: u8, remaining_length: usize, stream: &mut S) -> Result<(Self, usize), crate::packets::error::ReadError> { - let mut total_read_bytes = 0; + let mut total_read_bytes = 2; let packet_identifier = stream.read_u16().await?; - total_read_bytes += 2; let (properties, properties_read_bytes) = UnsubAckProperties::async_read(stream).await?; total_read_bytes += properties_read_bytes; let mut reason_codes = vec![]; loop { + if total_read_bytes >= remaining_length { + break; + } + let (reason_code, reason_code_read_bytes) = UnsubAckReasonCode::async_read(stream).await?; total_read_bytes += reason_code_read_bytes; reason_codes.push(reason_code); - - if total_read_bytes >= remaining_length { - break; - } } Ok(( @@ -114,6 +115,12 @@ where } } +impl WireLength for UnsubAck { + fn wire_len(&self) -> usize { + 2 + self.properties.wire_len().variable_integer_len() + self.properties.wire_len() + self.reason_codes.len() + } +} + #[cfg(test)] mod tests { use bytes::{Bytes, BytesMut}; diff --git a/mqrstt/src/packets/unsuback/properties.rs b/mqrstt/src/packets/unsuback/properties.rs index 85dcab0..d462c50 100644 --- a/mqrstt/src/packets/unsuback/properties.rs +++ b/mqrstt/src/packets/unsuback/properties.rs @@ -38,7 +38,7 @@ impl MqttRead for UnsubAckProperties { e => return Err(DeserializeError::UnexpectedProperty(e, PacketType::UnsubAck)), } - if buf.is_empty() { + if properties_data.is_empty() { break; } } diff --git a/mqrstt/src/tests/test_packets.rs b/mqrstt/src/tests/test_packets.rs index c8351e5..2f9dac1 100644 --- a/mqrstt/src/tests/test_packets.rs +++ b/mqrstt/src/tests/test_packets.rs @@ -276,6 +276,60 @@ pub fn create_disconnect_packet() -> Packet { }) } +pub fn suback_case() -> Packet { + let expected = SubAck { + packet_identifier: 3, + reason_codes: vec![SubAckReasonCode::GrantedQoS0, SubAckReasonCode::GrantedQoS1, SubAckReasonCode::GrantedQoS2], + properties: SubAckProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::SubAck(expected) +} + +pub fn subscribe_case() -> Packet { + let expected = Subscribe { + packet_identifier: 3, + topics: vec![("test/topic".into(), SubscriptionOptions::default())], + properties: SubscribeProperties { + user_properties: vec![(String::from("test").into(), String::from("test").into())], + subscription_identifier: Some(2000), + }, + }; + + Packet::Subscribe(expected) +} + +// return a crazy big packet +pub fn unsuback_case() -> Packet { + let expected = UnsubAck { + packet_identifier: 3, + reason_codes: vec![ + UnsubAckReasonCode::NoSubscriptionExisted, + UnsubAckReasonCode::UnspecifiedError, + UnsubAckReasonCode::ImplementationSpecificError, + ], + properties: UnsubAckProperties { + user_properties: vec![], + reason_string: None, + }, + }; + + Packet::UnsubAck(expected) +} + +pub fn unsubscribe_case() -> Packet { + let expected = Unsubscribe { + packet_identifier: 3, + topics: vec!["test/topic".into()], + properties: UnsubscribeProperties { user_properties: vec![] }, + }; + + Packet::Unsubscribe(expected) +} + #[rstest] #[case(create_subscribe_packet(1))] #[case(create_subscribe_packet(65335))]