diff --git a/Cargo.toml b/Cargo.toml index 3bd0553..8054f47 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,8 @@ license = "MPL-2.0" keywords = ["MQTT", "IoT", "MQTTv5", "messaging", "client"] description = "Pure rust MQTTv5 client implementation Smol and Tokio" +rust-version = "1.75" + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] diff --git a/src/packets/connack.rs b/src/packets/connack.rs index 19e9ca3..7c2e0b6 100644 --- a/src/packets/connack.rs +++ b/src/packets/connack.rs @@ -1,9 +1,9 @@ use super::{ - error::DeserializeError, - mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead}, + error::{DeserializeError, SerializeError}, + mqtt_traits::{MqttRead, MqttWrite, VariableHeaderRead, WireLength, VariableHeaderWrite}, read_variable_integer, reason_codes::ConnAckReasonCode, - PacketType, PropertyType, QoS, + PacketType, PropertyType, QoS, write_variable_integer, variable_integer_len, }; use bytes::{Buf, BufMut, Bytes}; @@ -38,6 +38,25 @@ impl VariableHeaderRead for ConnAck { } } +impl VariableHeaderWrite for ConnAck { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), SerializeError> { + + self.connack_flags.write(buf)?; + self.reason_code.write(buf)?; + self.connack_properties.write(buf)?; + + Ok(()) + } +} + +impl WireLength for ConnAck { + fn wire_len(&self) -> usize { + 2 + // 1 for connack_flags and 1 for reason_code + variable_integer_len(self.connack_properties.wire_len()) + + self.connack_properties.wire_len() + } +} + #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ConnAckProperties { /// 3.2.2.3.2 Session Expiry Interval @@ -106,6 +125,8 @@ pub struct ConnAckProperties { /// 3.2.2.3.18 Authentication Data /// 22 (0x16) Byte, Identifier of the Authentication Data + // There is a small inconsistency here with authentication_data in the connect packet. + // This is Option while that type uses just Bytes. pub authentication_data: Option, } @@ -235,6 +256,171 @@ impl MqttRead for ConnAckProperties { } } +impl MqttWrite for ConnAckProperties { + fn write(&self, buf: &mut bytes::BytesMut) -> Result<(), super::error::SerializeError> { + write_variable_integer(buf, self.wire_len())?; + + let Self { + session_expiry_interval, + receive_maximum, + maximum_qos, + retain_available, + maximum_packet_size, + assigned_client_id, + topic_alias_maximum, + reason_string, + user_properties, + wildcards_available, + subscription_ids_available, + shared_subscription_available, + server_keep_alive, + response_info, + server_reference, + authentication_method, + authentication_data, + } = self; + + + + if let Some(session_expiry_interval) = session_expiry_interval { + PropertyType::SessionExpiryInterval.write(buf)?; + buf.put_u32(*session_expiry_interval); + } + if let Some(receive_maximum) = receive_maximum { + PropertyType::ReceiveMaximum.write(buf)?; + buf.put_u16(*receive_maximum); + } + if let Some(maximum_qos) = maximum_qos { + PropertyType::MaximumQos.write(buf)?; + maximum_qos.write(buf)?; + } + if let Some(retain_available) = retain_available{ + PropertyType::RetainAvailable.write(buf)?; + retain_available.write(buf)?; + } + if let Some(maximum_packet_size) = maximum_packet_size { + PropertyType::MaximumPacketSize.write(buf)?; + buf.put_u32(*maximum_packet_size); + } + if let Some(client_id) = assigned_client_id { + PropertyType::AssignedClientIdentifier.write(buf)?; + client_id.write(buf)?; + } + if let Some(topic_alias_maximum) = topic_alias_maximum { + PropertyType::TopicAliasMaximum.write(buf)?; + buf.put_u16(*topic_alias_maximum); + } + if let Some(reason_string) = reason_string { + PropertyType::ReasonString.write(buf)?; + reason_string.write(buf)?; + } + for (key, val) in user_properties.iter() { + PropertyType::UserProperty.write(buf)?; + key.write(buf)?; + val.write(buf)?; + } + if let Some(wildcards_available) = wildcards_available{ + PropertyType::WildcardSubscriptionAvailable.write(buf)?; + wildcards_available.write(buf)?; + } + if let Some(subscription_ids_available) = subscription_ids_available{ + PropertyType::SubscriptionIdentifierAvailable.write(buf)?; + subscription_ids_available.write(buf)?; + } + if let Some(shared_subscription_available) = shared_subscription_available{ + PropertyType::SharedSubscriptionAvailable.write(buf)?; + shared_subscription_available.write(buf)?; + } + if let Some(server_keep_alive) = server_keep_alive{ + PropertyType::ServerKeepAlive.write(buf)?; + server_keep_alive.write(buf)?; + } + if let Some(response_info) = response_info { + PropertyType::ResponseInformation.write(buf)?; + response_info.write(buf)?; + } + if let Some(server_reference) = server_reference { + PropertyType::ServerReference.write(buf)?; + server_reference.write(buf)?; + } + if let Some(authentication_method) = &authentication_method { + PropertyType::AuthenticationMethod.write(buf)?; + authentication_method.write(buf)?; + } + if let Some(authentication_data) = authentication_data { + if authentication_method.is_none() { + return Err(SerializeError::AuthDataWithoutAuthMethod); + } + PropertyType::AuthenticationData.write(buf)?; + authentication_data.write(buf)?; + } + + Ok(()) + } +} + +impl WireLength for ConnAckProperties{ + fn wire_len(&self) -> usize { + let mut len: usize = 0; + + if self.session_expiry_interval.is_some() { + len += 1 + 4; + } + if self.receive_maximum.is_some() { + len += 1 + 2; + } + if self.maximum_qos.is_some() { + len += 1 + 1; + } + if self.retain_available.is_some() { + len += 1 + 1; + } + if self.maximum_packet_size.is_some() { + len += 1 + 4; + } + if let Some(client_id) = &self.assigned_client_id { + len += 1 + client_id.wire_len(); + } + if self.topic_alias_maximum.is_some() { + len += 1 + 2; + } + if let Some(reason_string) = &self.reason_string { + len += 1 + reason_string.wire_len(); + } + for (key, value) in &self.user_properties { + len += 1; + len += key.wire_len(); + len += value.wire_len(); + } + if self.wildcards_available.is_some() { + len += 1 + 1; + } + if self.subscription_ids_available.is_some() { + len += 1 + 1; + } + if self.shared_subscription_available.is_some() { + len += 1 + 1; + } + if self.server_keep_alive.is_some() { + len += 1 + 2; + } + if let Some(response_info) = &self.response_info { + len += 1 + response_info.wire_len(); + } + if let Some(server_reference) = &self.server_reference { + len += 1 + server_reference.wire_len(); + } + if let Some(authentication_method) = &self.authentication_method { + len += 1 + authentication_method.wire_len(); + } + if self.authentication_data.is_some() && self.authentication_method.is_some() { + len += 1 + self.authentication_data.as_ref().map(WireLength::wire_len).unwrap_or(0); + } + + len + } +} + #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct ConnAckFlags { pub session_present: bool, @@ -249,7 +435,7 @@ impl MqttRead for ConnAckFlags { let byte = buf.get_u8(); Ok(Self { - session_present: (byte & 0b00000001) != 0, + session_present: (byte & 0b00000001) == 0b00000001, }) } } @@ -265,49 +451,88 @@ impl MqttWrite for ConnAckFlags { #[cfg(test)] mod tests { + use bytes::Buf; + use crate::packets::{ connack::{ConnAck, ConnAckProperties}, - mqtt_traits::{MqttRead, VariableHeaderRead}, - reason_codes::ConnAckReasonCode, + mqtt_traits::{MqttRead, VariableHeaderRead, MqttWrite, VariableHeaderWrite}, + reason_codes::ConnAckReasonCode, Packet, FixedHeader, PacketType, }; #[test] - fn read_connack() { + fn read_write_connack_packet() { + let c = ConnAck{ + ..Default::default() + }; + + let p1 = Packet::ConnAck(c); + let mut buf = bytes::BytesMut::new(); + + p1.write(&mut buf).unwrap(); + + let p2 = Packet::read_from_buffer(&mut buf).unwrap(); + + assert_eq!(p1, p2); + } + + #[test] + fn read_write_connack() { let mut buf = bytes::BytesMut::new(); let packet = &[ 0x01, // Connack flags 0x00, // Reason code, 0x00, // empty properties ]; - + buf.extend_from_slice(packet); - let c = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + let c1 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(ConnAckReasonCode::Success, c1.reason_code); + assert_eq!(ConnAckProperties::default(), c1.connack_properties); + + let mut buf = bytes::BytesMut::new(); + + c1.write(&mut buf).unwrap(); + + let c2 = ConnAck::read(0, packet.len(), buf.into()).unwrap(); + + assert_eq!(c1, c2) - assert_eq!(ConnAckReasonCode::Success, c.reason_code); - assert_eq!(ConnAckProperties::default(), c.connack_properties); } + #[test] - fn read_connack_properties() { + fn read_write_connack_properties() { let mut buf = bytes::BytesMut::new(); let packet = &[ 56, // ConnAckProperties variable length 17, // session_expiry_interval - 0xff, 0xff, 37, // retain_available + 0xff, 0xff, + 37, // retain_available 0x1, // true 18, // Assigned Client Id - 0, 11, // 11 bytes - b'K', // Keanu Reeves without space - b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', 36, // Max QoS + 0, 11, // 11 bytes "KeanuReeves" without space + b'K', b'e', b'a', b'n', b'u', b'R', b'e', b'e', b'v', b'e', b's', + 36, // Max QoS 2, // QoS 2 Exactly Once 34, // Topic Alias Max = 255 - 0, 255, 31, // Reason String = 'Houston we have got a problem' + 0, 255, + 31, // Reason String = 'Houston we have got a problem' 0, 29, b'H', b'o', b'u', b's', b't', b'o', b'n', b' ', b'w', b'e', b' ', b'h', b'a', b'v', b'e', b' ', b'g', b'o', b't', b' ', b'a', b' ', b'p', b'r', b'o', b'b', b'l', b'e', b'm', ]; buf.extend_from_slice(packet); - let c = ConnAckProperties::read(&mut buf.into()).unwrap(); + let c1 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + let mut buf = bytes::BytesMut::new(); - dbg!(c); + c1.write(&mut buf).unwrap(); + + let buf_clone = buf.to_vec(); + + let c2 = ConnAckProperties::read(&mut buf.into()).unwrap(); + + assert_eq!(c1, c2); } + } diff --git a/src/packets/mod.rs b/src/packets/mod.rs index 8588385..1b75298 100644 --- a/src/packets/mod.rs +++ b/src/packets/mod.rs @@ -91,6 +91,7 @@ impl QoS { } impl MqttRead for QoS { + #[inline] fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { return Err(DeserializeError::InsufficientData("QoS".to_string(), 0, 1)); @@ -106,6 +107,7 @@ impl MqttRead for QoS { } impl MqttWrite for QoS { + #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { let val = match self { QoS::AtMostOnce => 0, @@ -118,6 +120,7 @@ impl MqttWrite for QoS { } impl MqttRead for Box { + #[inline] fn read(buf: &mut Bytes) -> Result { let content = Bytes::read(buf)?; @@ -143,6 +146,7 @@ impl WireLength for Box { } impl MqttWrite for &str { + #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.len() as u16); buf.extend(self.as_bytes()); @@ -158,6 +162,7 @@ impl WireLength for &str { } impl MqttRead for String { + #[inline] fn read(buf: &mut Bytes) -> Result { let content = Bytes::read(buf)?; @@ -169,6 +174,7 @@ impl MqttRead for String { } impl MqttWrite for String { + #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { if self.len() > 65535 { return Err(SerializeError::StringTooLong(self.len())); @@ -188,6 +194,7 @@ impl WireLength for String { } impl MqttRead for Bytes { + #[inline] fn read(buf: &mut Bytes) -> Result { let len = buf.get_u16() as usize; @@ -200,6 +207,7 @@ impl MqttRead for Bytes { } impl MqttWrite for Bytes { + #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { buf.put_u16(self.len() as u16); buf.extend(self); @@ -230,6 +238,7 @@ impl MqttRead for bool { } impl MqttWrite for bool { + #[inline] fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { if *self { buf.put_u8(1); @@ -242,6 +251,7 @@ impl MqttWrite for bool { } impl MqttRead for u8 { + #[inline] fn read(buf: &mut Bytes) -> Result { if buf.is_empty() { return Err(DeserializeError::InsufficientData("u8".to_string(), 0, 1)); @@ -251,6 +261,7 @@ impl MqttRead for u8 { } impl MqttRead for u16 { + #[inline] fn read(buf: &mut Bytes) -> Result { if buf.len() < 2 { return Err(DeserializeError::InsufficientData("u16".to_string(), buf.len(), 2)); @@ -259,7 +270,16 @@ impl MqttRead for u16 { } } +impl MqttWrite for u16 { + #[inline] + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u16(*self); + Ok(()) + } +} + impl MqttRead for u32 { + #[inline] fn read(buf: &mut Bytes) -> Result { if buf.len() < 4 { return Err(DeserializeError::InsufficientData("u32".to_string(), buf.len(), 4)); @@ -268,6 +288,13 @@ impl MqttRead for u32 { } } +impl MqttWrite for u32 { + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + buf.put_u32(*self); + Ok(()) + } +} + pub fn read_fixed_header_rem_len(mut buf: Iter) -> Result<(usize, usize), ReadBytes> { let mut integer = 0; let mut length = 0; @@ -563,8 +590,10 @@ impl Packet { p.write(buf)?; } - Packet::ConnAck(_) => { - unreachable!() + Packet::ConnAck(p) => { + buf.put_u8(0b0010_0000); + write_variable_integer(buf, p.wire_len())?; + p.write(buf)?; } Packet::Publish(p) => { let mut first_byte = 0b0011_0000u8; @@ -615,8 +644,8 @@ impl Packet { p.write(buf)?; } Packet::UnsubAck(_) => { + unreachable!(); buf.put_u8(0b1011_0000); - unreachable!() } Packet::PingReq => { buf.put_u8(0b1100_0000); diff --git a/src/packets/mqtt_traits.rs b/src/packets/mqtt_traits.rs index cd20f0f..502366d 100644 --- a/src/packets/mqtt_traits.rs +++ b/src/packets/mqtt_traits.rs @@ -22,6 +22,15 @@ pub trait MqttWrite: Sized { fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError>; } +impl MqttWrite for &T +where + T: MqttWrite +{ + fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> { + ::write(self, buf) + } +} + pub trait PacketValidation: Sized { fn validate(&self, max_packet_size: usize) -> Result<(), crate::error::PacketValidationError>; }