diff --git a/Cargo.toml b/Cargo.toml index e1cd889..a21fe7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,12 @@ default = ["std"] # Implements serde::{Serialize,Deserialize} on mqttrs::Pid. derive = ["serde"] -std = ["bytes/std", "serde/std"] +std = ["bytes", "bytes/std", "serde/std"] [dependencies] -bytes = { version = "0.5", default-features = false } +bytes = { version = "0.5", default-features = false, optional = true } serde = { version = "1.0", features = ["derive"], optional = true } +heapless = "0.5.5" [dev-dependencies] -proptest = "0.9.4" +proptest = "0.10.0" diff --git a/src/codec_test.rs b/src/codec_test.rs index ae0285b..f271997 100644 --- a/src/codec_test.rs +++ b/src/codec_test.rs @@ -2,8 +2,6 @@ use crate::*; use bytes::BytesMut; use proptest::{bool, collection::vec, num::*, prelude::*}; use core::convert::TryFrom; -use alloc::string::String; -use alloc::format; // Proptest strategies to generate packet elements prop_compose! { @@ -172,19 +170,19 @@ macro_rules! impl_proptests { // modified to accept other buffer types. // Check that encoding into a small buffer fails cleanly - //buf.clear(); - //buf.split_off(encoded.len()); - //prop_assert!(encoded.len() == buf.remaining_mut() && buf.is_empty(), + // buf.clear(); + // buf.split_off(encoded.len()); + // prop_assert!(encoded.len() == buf.remaining_mut() && buf.is_empty(), // "Wrong buffer init1 {}/{}/{}", encoded.len(), buf.remaining_mut(), buf.is_empty()); - //prop_assert!(encode(&pkt, &mut buf).is_ok(), "exact buffer capacity {}", buf.capacity()); - //for l in (0..encoded.len()).rev() { + // prop_assert!(encode(&pkt, &mut buf).is_ok(), "exact buffer capacity {}", buf.capacity()); + // for l in (0..encoded.len()).rev() { // buf.clear(); // buf.split_to(1); // prop_assert!(l == buf.remaining_mut() && buf.is_empty(), // "Wrong buffer init2 {}/{}/{}", l, buf.remaining_mut(), buf.is_empty()); // prop_assert_eq!(Err(Error::WriteZero), encode(&pkt, &mut buf), // "small buffer capacity {}/{}", buf.capacity(), encoded.len()); - //} + // } } } }; diff --git a/src/connect.rs b/src/connect.rs index 784ea93..ef627e8 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,6 +1,4 @@ use crate::{decoder::*, encoder::*, *}; -use alloc::{string::String, vec::Vec}; -use bytes::{Buf, BufMut, BytesMut}; /// Protocol version. /// @@ -28,18 +26,29 @@ impl Protocol { _ => Err(Error::InvalidProtocol(name.into(), level)), } } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + pub(crate) fn from_buffer<'a>(buf: &'a [u8], offset: &mut usize) -> Result { + let protocol_name = read_str(buf, offset)?; + let protocol_level = buf[*offset]; + *offset += 1; + + Protocol::new(protocol_name, protocol_level) + } + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { match self { Protocol::MQTT311 => { let slice = &[0u8, 4, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 4]; - buf.put_slice(slice); + for &byte in slice { + write_u8(buf, offset, byte)?; + } Ok(slice.len()) } Protocol::MQIsdp => { let slice = &[ 0u8, 4, 'M' as u8, 'Q' as u8, 'i' as u8, 's' as u8, 'd' as u8, 'p' as u8, 4, ]; - buf.put_slice(slice); + for &byte in slice { + write_u8(buf, offset, byte)?; + } Ok(slice.len()) } } @@ -53,9 +62,9 @@ impl Protocol { /// [Connect]: struct.Connect.html /// [MQTT 3.1.3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031 #[derive(Debug, Clone, PartialEq)] -pub struct LastWill { - pub topic: String, - pub message: Vec, +pub struct LastWill<'a> { + pub topic: &'a str, + pub message: &'a [u8], pub qos: QoS, pub retain: bool, } @@ -103,14 +112,14 @@ impl ConnectReturnCode { /// /// [MQTT 3.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028 #[derive(Debug, Clone, PartialEq)] -pub struct Connect { +pub struct Connect<'a> { pub protocol: Protocol, pub keep_alive: u16, - pub client_id: String, + pub client_id: &'a str, pub clean_session: bool, - pub last_will: Option, - pub username: Option, - pub password: Option>, + pub last_will: Option>, + pub username: Option<&'a str>, + pub password: Option<&'a [u8]>, } /// Connack packet ([MQTT 3.2]). @@ -122,20 +131,19 @@ pub struct Connack { pub code: ConnectReturnCode, } -impl Connect { - pub(crate) fn from_buffer(buf: &mut BytesMut) -> Result { - let protocol_name = read_string(buf)?; - let protocol_level = buf.get_u8(); - let protocol = Protocol::new(&protocol_name, protocol_level)?; +impl<'a> Connect<'a> { + pub(crate) fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result { + let protocol = Protocol::from_buffer(buf, offset)?; - let connect_flags = buf.get_u8(); - let keep_alive = buf.get_u16(); + let connect_flags = buf[*offset]; + let keep_alive = ((buf[*offset + 1] as u16) << 8) | buf[*offset + 2] as u16; + *offset += 3; - let client_id = read_string(buf)?; + let client_id = read_str(buf, offset)?; let last_will = if connect_flags & 0b100 != 0 { - let will_topic = read_string(buf)?; - let will_message = read_bytes(buf)?; + let will_topic = read_str(buf, offset)?; + let will_message = read_bytes(buf, offset)?; let will_qod = QoS::from_u8((connect_flags & 0b11000) >> 3)?; Some(LastWill { topic: will_topic, @@ -148,13 +156,13 @@ impl Connect { }; let username = if connect_flags & 0b10000000 != 0 { - Some(read_string(buf)?) + Some(read_str(buf, offset)?) } else { None }; let password = if connect_flags & 0b01000000 != 0 { - Some(read_bytes(buf)?) + Some(read_bytes(buf, offset)?) } else { None }; @@ -171,7 +179,8 @@ impl Connect { clean_session, }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { let header: u8 = 0b00010000; let mut length: usize = 6 + 1 + 1; // NOTE: protocol_name(6) + protocol_level(1) + flags(1); let mut connect_flags: u8 = 0b00000000; @@ -180,12 +189,12 @@ impl Connect { }; length += 2 + self.client_id.len(); length += 2; // keep alive - if let Some(username) = &self.username { + if let Some(username) = self.username { connect_flags |= 0b10000000; length += username.len(); length += 2; }; - if let Some(password) = &self.password { + if let Some(password) = self.password { connect_flags |= 0b01000000; length += password.len(); length += 2; @@ -200,26 +209,29 @@ impl Connect { length += last_will.topic.len(); length += 4; }; - check_remaining(buf, length + 1)?; + check_remaining(buf, offset, length + 1)?; // NOTE: putting data into buffer. - buf.put_u8(header); - let write_len = write_length(length, buf)? + 1; - self.protocol.to_buffer(buf)?; - buf.put_u8(connect_flags); - buf.put_u16(self.keep_alive); - write_string(self.client_id.as_ref(), buf)?; + write_u8(buf, offset, header)?; + + let write_len = write_length(buf, offset, length)? + 1; + self.protocol.to_buffer(buf, offset)?; + + write_u8(buf, offset, connect_flags)?; + write_u16(buf, offset, self.keep_alive)?; + + write_string(buf, offset, self.client_id)?; if let Some(last_will) = &self.last_will { - write_string(last_will.topic.as_ref(), buf)?; - write_bytes(&last_will.message, buf)?; + write_string(buf, offset, last_will.topic)?; + write_bytes(buf, offset, &last_will.message)?; }; - if let Some(username) = &self.username { - write_string(username.as_ref(), buf)?; + if let Some(username) = self.username { + write_string(buf, offset, username)?; }; - if let Some(password) = &self.password { - write_bytes(password, buf)?; + if let Some(password) = self.password { + write_bytes(buf, offset, password)?; }; // NOTE: END Ok(write_len) @@ -227,16 +239,17 @@ impl Connect { } impl Connack { - pub(crate) fn from_buffer(buf: &mut impl Buf) -> Result { - let flags = buf.get_u8(); - let return_code = buf.get_u8(); + pub(crate) fn from_buffer<'a>(buf: &'a [u8], offset: &mut usize) -> Result { + let flags = buf[*offset]; + let return_code = buf[*offset + 1]; + *offset += 2; Ok(Connack { session_present: (flags & 0b1 == 1), code: ConnectReturnCode::from_u8(return_code)?, }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { - check_remaining(buf, 4)?; + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { + check_remaining(buf, offset, 4)?; let header: u8 = 0b00100000; let length: u8 = 2; let mut flags: u8 = 0b00000000; @@ -244,10 +257,10 @@ impl Connack { flags |= 0b1; }; let rc = self.code.to_u8(); - buf.put_u8(header); - buf.put_u8(length); - buf.put_u8(flags); - buf.put_u8(rc); + write_u8(buf, offset, header)?; + write_u8(buf, offset, length)?; + write_u8(buf, offset, flags)?; + write_u8(buf, offset, rc)?; Ok(4) } } diff --git a/src/decoder.rs b/src/decoder.rs index 0ba8932..06f4b33 100644 --- a/src/decoder.rs +++ b/src/decoder.rs @@ -1,6 +1,8 @@ use crate::*; -use alloc::{string::String, vec::Vec}; -use bytes::{Buf, BytesMut}; + +// use alloc::{string::String, vec::Vec}; +use heapless::{String, Vec, ArrayLength}; + /// Decode bytes from a [BytesMut] buffer as a [Packet] enum. /// @@ -16,9 +18,9 @@ use bytes::{Buf, BytesMut}; /// 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8] as &[u8]); /// /// // Parse the bytes and check the result. -/// match decode(&mut buf) { +/// match decode_slice(&mut buf) { /// Ok(Some(Packet::Publish(p))) => { -/// assert_eq!(p.payload, "hello".as_bytes().to_vec()); +/// assert_eq!(p.payload, b"hello"); /// }, /// // In real code you probably don't want to panic like that ;) /// Ok(None) => panic!("not enough data"), @@ -28,51 +30,87 @@ use bytes::{Buf, BytesMut}; /// /// [Packet]: ../enum.Packet.html /// [BytesMut]: https://docs.rs/bytes/0.5.3/bytes/struct.BytesMut.html -pub fn decode(buf: &mut BytesMut) -> Result, Error> { - if let Some((header, remaining_len)) = read_header(buf)? { - // Advance the buffer position to the next packet, and parse the current packet - Ok(Some(read_packet(header, &mut buf.split_to(remaining_len))?)) + +pub fn clone_packet(input: &[u8], output: &mut [u8]) -> Result { + if input.is_empty() { + return Ok(0); + } + + let mut offset = 0; + // while Header::new(input[offset]).is_err() { + // offset += 1; + // if input[offset..].is_empty() { + // return Ok(0); + // } + // } + + let start = offset; + if let Some((_, remaining_len)) = read_header(input, &mut offset)? { + let end = offset + remaining_len; + let len = end - start; + output[..len].copy_from_slice(&input[start..end]); + Ok(len) + } else { + // Don't have a full packet + Ok(0) + } +} + +pub fn decode_slice<'a>(buf: &'a [u8]) -> Result>, Error> { + let mut offset = 0; + if let Some((header, remaining_len)) = read_header(buf, &mut offset)? { + let r = read_packet(header, remaining_len, buf, &mut offset)?; + Ok(Some(r)) } else { // Don't have a full packet Ok(None) } } -fn read_packet(header: Header, buf: &mut BytesMut) -> Result { +fn read_packet<'a>( + header: Header, + remaining_len: usize, + buf: &'a [u8], + offset: &mut usize, +) -> Result, Error> { Ok(match header.typ { PacketType::Pingreq => Packet::Pingreq, PacketType::Pingresp => Packet::Pingresp, PacketType::Disconnect => Packet::Disconnect, - PacketType::Connect => Connect::from_buffer(buf)?.into(), - PacketType::Connack => Connack::from_buffer(buf)?.into(), - PacketType::Publish => Publish::from_buffer(&header, buf)?.into(), - PacketType::Puback => Packet::Puback(Pid::from_buffer(buf)?), - PacketType::Pubrec => Packet::Pubrec(Pid::from_buffer(buf)?), - PacketType::Pubrel => Packet::Pubrel(Pid::from_buffer(buf)?), - PacketType::Pubcomp => Packet::Pubcomp(Pid::from_buffer(buf)?), - PacketType::Subscribe => Subscribe::from_buffer(buf)?.into(), - PacketType::Suback => Suback::from_buffer(buf)?.into(), - PacketType::Unsubscribe => Unsubscribe::from_buffer(buf)?.into(), - PacketType::Unsuback => Packet::Unsuback(Pid::from_buffer(buf)?), + PacketType::Connect => Connect::from_buffer(buf, offset)?.into(), + PacketType::Connack => Connack::from_buffer(buf, offset)?.into(), + PacketType::Publish => Publish::from_buffer(&header, remaining_len, buf, offset)?.into(), + PacketType::Puback => Packet::Puback(Pid::from_buffer(buf, offset)?), + PacketType::Pubrec => Packet::Pubrec(Pid::from_buffer(buf, offset)?), + PacketType::Pubrel => Packet::Pubrel(Pid::from_buffer(buf, offset)?), + PacketType::Pubcomp => Packet::Pubcomp(Pid::from_buffer(buf, offset)?), + PacketType::Subscribe => Subscribe::from_buffer(remaining_len, buf, offset)?.into(), + PacketType::Suback => Suback::from_buffer(remaining_len, buf, offset)?.into(), + PacketType::Unsubscribe => Unsubscribe::from_buffer(remaining_len, buf, offset)?.into(), + PacketType::Unsuback => Packet::Unsuback(Pid::from_buffer(buf, offset)?), }) } /// Read the parsed header and remaining_len from the buffer. Only return Some() and advance the /// buffer position if there is enough data in the buffer to read the full packet. -fn read_header(buf: &mut BytesMut) -> Result, Error> { +pub(crate) fn read_header<'a>( + buf: &'a [u8], + offset: &mut usize, +) -> Result, Error> { let mut len: usize = 0; for pos in 0..=3 { - if let Some(&byte) = buf.get(pos + 1) { + if buf.len() > *offset + pos + 1 { + let byte = buf[*offset + pos + 1]; len += (byte as usize & 0x7F) << (pos * 7); if (byte & 0x80) == 0 { // Continuation bit == 0, length is parsed - if buf.remaining() < 2 + pos + len { + if buf.len() < *offset + 2 + pos + len { // Won't be able to read full packet return Ok(None); } // Parse header byte, skip past the header, and return - let header = Header::new(buf.get_u8())?; - buf.advance(pos + 1); + let header = Header::new(buf[*offset])?; + *offset += pos + 2; return Ok(Some((header, len))); } } else { @@ -122,127 +160,21 @@ impl Header { } } -pub(crate) fn read_string(buf: &mut BytesMut) -> Result { - String::from_utf8(read_bytes(buf)?).map_err(|e| Error::InvalidString(e.utf8_error())) +pub(crate) fn read_str<'a>(buf: &'a [u8], offset: &mut usize) -> Result<&'a str, Error> { + core::str::from_utf8(read_bytes(buf, offset)?).map_err(|e| Error::InvalidString(e)) } -pub(crate) fn read_bytes(buf: &mut BytesMut) -> Result, Error> { - let len = buf.get_u16() as usize; - if len > buf.remaining() { +pub(crate) fn read_bytes<'a>(buf: &'a [u8], offset: &mut usize) -> Result<&'a [u8], Error> { + if buf[*offset..].len() < 2 { + return Err(Error::InvalidLength); + } + let len = ((buf[*offset] as usize) << 8) | buf[*offset + 1] as usize; + *offset += 2; + if len > buf[*offset..].len() { Err(Error::InvalidLength) } else { - Ok(buf.split_to(len).to_vec()) - } -} - -#[cfg(test)] -mod test { - use crate::decoder::*; - use alloc::vec; - - macro_rules! header { - ($t:ident, $d:expr, $q:ident, $r:expr) => { - Header { - typ: PacketType::$t, - dup: $d, - qos: QoS::$q, - retain: $r, - } - }; - } - - fn bm(d: &[u8]) -> BytesMut { - BytesMut::from(d) - } - - /// Test all possible header first byte, using remaining_len=0. - #[test] - fn header_firstbyte() { - let valid = vec![ - (0b0001_0000, header!(Connect, false, AtMostOnce, false)), - (0b0010_0000, header!(Connack, false, AtMostOnce, false)), - (0b0011_0000, header!(Publish, false, AtMostOnce, false)), - (0b0011_0001, header!(Publish, false, AtMostOnce, true)), - (0b0011_0010, header!(Publish, false, AtLeastOnce, false)), - (0b0011_0011, header!(Publish, false, AtLeastOnce, true)), - (0b0011_0100, header!(Publish, false, ExactlyOnce, false)), - (0b0011_0101, header!(Publish, false, ExactlyOnce, true)), - (0b0011_1000, header!(Publish, true, AtMostOnce, false)), - (0b0011_1001, header!(Publish, true, AtMostOnce, true)), - (0b0011_1010, header!(Publish, true, AtLeastOnce, false)), - (0b0011_1011, header!(Publish, true, AtLeastOnce, true)), - (0b0011_1100, header!(Publish, true, ExactlyOnce, false)), - (0b0011_1101, header!(Publish, true, ExactlyOnce, true)), - (0b0100_0000, header!(Puback, false, AtMostOnce, false)), - (0b0101_0000, header!(Pubrec, false, AtMostOnce, false)), - (0b0110_0010, header!(Pubrel, false, AtLeastOnce, false)), - (0b0111_0000, header!(Pubcomp, false, AtMostOnce, false)), - (0b1000_0010, header!(Subscribe, false, AtLeastOnce, false)), - (0b1001_0000, header!(Suback, false, AtMostOnce, false)), - (0b1010_0010, header!(Unsubscribe, false, AtLeastOnce, false)), - (0b1011_0000, header!(Unsuback, false, AtMostOnce, false)), - (0b1100_0000, header!(Pingreq, false, AtMostOnce, false)), - (0b1101_0000, header!(Pingresp, false, AtMostOnce, false)), - (0b1110_0000, header!(Disconnect, false, AtMostOnce, false)), - ]; - for n in 0..=255 { - let res = match valid.iter().find(|(byte, _)| *byte == n) { - Some((_, header)) => Ok(Some((*header, 0))), - None if ((n & 0b110) == 0b110) && (n >> 4 == 3) => Err(Error::InvalidQos(3)), - None => Err(Error::InvalidHeader), - }; - let mut buf = bm(&[n, 0]); - assert_eq!(res, read_header(&mut buf), "{:08b}", n); - } - } - - /// Test decoding of length and actual buffer len. - #[rustfmt::skip] - #[test] - fn header_len() { - let h = header!(Connect, false, AtMostOnce, false); - for (res, mut bytes, buflen) in vec![ - (Ok(Some((h, 0))), vec![1 << 4, 0], 2), - (Ok(None), vec![1 << 4, 127], 128), - (Ok(Some((h, 127))), vec![1 << 4, 127], 129), - (Ok(None), vec![1 << 4, 0x80], 2), - (Ok(Some((h, 0))), vec![1 << 4, 0x80, 0], 3), //Weird encoding for "0" buf matches spec - (Ok(Some((h, 128))), vec![1 << 4, 0x80, 1], 131), - (Ok(None), vec![1 << 4, 0x80+16, 78], 10002), - (Ok(Some((h, 10000))), vec![1 << 4, 0x80+16, 78], 10003), - (Err(Error::InvalidHeader), vec![1 << 4, 0x80, 0x80, 0x80, 0x80], 10), - ] { - bytes.resize(buflen, 0); - let mut buf = bm(bytes.as_slice()); - assert_eq!(res, read_header(&mut buf)); - } - } - - #[test] - fn non_utf8_string() { - let mut data = bm(&[ - 0b00110000, 10, // type=Publish, remaining_len=10 - 0x00, 0x03, 'a' as u8, '/' as u8, 0xc0 as u8, // Topic with Invalid utf8 - 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, // payload - ]); - assert!(match decode(&mut data) { - Err(Error::InvalidString(_)) => true, - _ => false, - }); - } - - /// Validity of remaining_len is tested exhaustively elsewhere, this is for inner lengths, which - /// are rarer. - #[test] - fn inner_length_too_long() { - let mut data = bm(&[ - 0b00010000, 20, // Connect packet, remaining_len=20 - 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, - 0b01000000, // +password - 0x00, 0x0a, // keepalive 10 sec - 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id - 0x00, 0x03, 'm' as u8, 'q' as u8, // password with invalid length - ]); - assert_eq!(Err(Error::InvalidLength), decode(&mut data)); + let bytes = &buf[*offset..*offset + len]; + *offset += len; + Ok(bytes) } } diff --git a/src/decoder_test.rs b/src/decoder_test.rs index 307f4b2..2172bf3 100644 --- a/src/decoder_test.rs +++ b/src/decoder_test.rs @@ -1,14 +1,143 @@ use crate::*; -use alloc::string::{String, ToString}; use bytes::BytesMut; +use subscribe::LimitedString; + +macro_rules! header { + ($t:ident, $d:expr, $q:ident, $r:expr) => { + decoder::Header { + typ: PacketType::$t, + dup: $d, + qos: QoS::$q, + retain: $r, + } + }; +} fn bm(d: &[u8]) -> BytesMut { BytesMut::from(d) } +/// Test all possible header first byte, using remaining_len=0. #[test] -fn test_half_connect() { +fn header_firstbyte() { + let valid = vec![ + (0b0001_0000, header!(Connect, false, AtMostOnce, false)), + (0b0010_0000, header!(Connack, false, AtMostOnce, false)), + (0b0011_0000, header!(Publish, false, AtMostOnce, false)), + (0b0011_0001, header!(Publish, false, AtMostOnce, true)), + (0b0011_0010, header!(Publish, false, AtLeastOnce, false)), + (0b0011_0011, header!(Publish, false, AtLeastOnce, true)), + (0b0011_0100, header!(Publish, false, ExactlyOnce, false)), + (0b0011_0101, header!(Publish, false, ExactlyOnce, true)), + (0b0011_1000, header!(Publish, true, AtMostOnce, false)), + (0b0011_1001, header!(Publish, true, AtMostOnce, true)), + (0b0011_1010, header!(Publish, true, AtLeastOnce, false)), + (0b0011_1011, header!(Publish, true, AtLeastOnce, true)), + (0b0011_1100, header!(Publish, true, ExactlyOnce, false)), + (0b0011_1101, header!(Publish, true, ExactlyOnce, true)), + (0b0100_0000, header!(Puback, false, AtMostOnce, false)), + (0b0101_0000, header!(Pubrec, false, AtMostOnce, false)), + (0b0110_0010, header!(Pubrel, false, AtLeastOnce, false)), + (0b0111_0000, header!(Pubcomp, false, AtMostOnce, false)), + (0b1000_0010, header!(Subscribe, false, AtLeastOnce, false)), + (0b1001_0000, header!(Suback, false, AtMostOnce, false)), + (0b1010_0010, header!(Unsubscribe, false, AtLeastOnce, false)), + (0b1011_0000, header!(Unsuback, false, AtMostOnce, false)), + (0b1100_0000, header!(Pingreq, false, AtMostOnce, false)), + (0b1101_0000, header!(Pingresp, false, AtMostOnce, false)), + (0b1110_0000, header!(Disconnect, false, AtMostOnce, false)), + ]; + for n in 0..=255 { + let res = match valid.iter().find(|(byte, _)| *byte == n) { + Some((_, header)) => Ok(Some((*header, 0))), + None if ((n & 0b110) == 0b110) && (n >> 4 == 3) => Err(Error::InvalidQos(3)), + None => Err(Error::InvalidHeader), + }; + let mut buf: &[u8] = &[n, 0]; + let mut offset = 0; + assert_eq!( + res, + decoder::read_header(&mut buf, &mut offset), + "{:08b}", + n + ); + if res.is_ok() { + assert_eq!(offset, 2); + } else { + assert_eq!(offset, 0); + } + } +} + +/// Test decoding of length and actual buffer len. +#[rustfmt::skip] +#[test] +fn header_len() { + let h = header!(Connect, false, AtMostOnce, false); + for (res, mut bytes, buflen) in vec![ + (Ok(Some((h, 0))), vec![1 << 4, 0], 2), + (Ok(None), vec![1 << 4, 127], 128), + (Ok(Some((h, 127))), vec![1 << 4, 127], 129), + (Ok(None), vec![1 << 4, 0x80], 2), + (Ok(Some((h, 0))), vec![1 << 4, 0x80, 0], 3), //Weird encoding for "0" buf matches spec + (Ok(Some((h, 128))), vec![1 << 4, 0x80, 1], 131), + (Ok(None), vec![1 << 4, 0x80+16, 78], 10002), + (Ok(Some((h, 10000))), vec![1 << 4, 0x80+16, 78], 10003), + (Err(Error::InvalidHeader), vec![1 << 4, 0x80, 0x80, 0x80, 0x80], 10), + ] { + let offset_expectation = bytes.len(); + bytes.resize(buflen, 0); + let mut slice_buf = bytes.as_slice(); + let mut offset = 0; + assert_eq!(res, decoder::read_header(&mut slice_buf, &mut offset)); + match res { + Ok(Some(_)) => assert_eq!(offset, offset_expectation), + _ => assert_eq!(offset, 0) + } + } +} + +#[test] +fn non_utf8_string() { + let mut data: &[u8] = &[ + 0b00110000, 10, // type=Publish, remaining_len=10 + 0x00, 0x03, 'a' as u8, '/' as u8, 0xc0 as u8, // Topic with Invalid utf8 + 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, // payload + ]; + assert!(match decode_slice(&mut data) { + Err(Error::InvalidString(_)) => true, + _ => false, + }); +} + +/// Validity of remaining_len is tested exhaustively elsewhere, this is for inner lengths, which +/// are rarer. +#[test] +fn inner_length_too_long() { let mut data = bm(&[ + 0b00010000, 20, // Connect packet, remaining_len=20 + 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, 0b01000000, // +password + 0x00, 0x0a, // keepalive 10 sec + 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id + 0x00, 0x03, 'm' as u8, 'q' as u8, // password with invalid length + ]); + assert_eq!(Err(Error::InvalidLength), decode_slice(&mut data)); + + let mut slice: &[u8] = &[ + 0b00010000, 20, // Connect packet, remaining_len=20 + 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, 0b01000000, // +password + 0x00, 0x0a, // keepalive 10 sec + 0x00, 0x04, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, // client_id + 0x00, 0x03, 'm' as u8, 'q' as u8, // password with invalid length + ]; + + assert_eq!(Err(Error::InvalidLength), decode_slice(&mut slice)); + // assert_eq!(slice, []); +} + +#[test] +fn test_half_connect() { + let mut data: &[u8] = &[ 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session 0x00, @@ -19,14 +148,14 @@ fn test_half_connect() { // 'e' as u8, // will msg = 'offline' // 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' // 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]); - assert_eq!(Ok(None), decode(&mut data)); + ]; + assert_eq!(Ok(None), decode_slice(&mut data)); assert_eq!(12, data.len()); } #[test] fn test_connect_wrong_version() { - let mut data = bm(&[ + let mut data: &[u8] = &[ 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x01, 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session 0x00, 0x0a, // 10 sec @@ -36,13 +165,13 @@ fn test_connect_wrong_version() { 'e' as u8, // will msg = 'offline' 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]); - assert!(decode(&mut data).is_err(), "Unknown version should return error"); + ]; + assert!(decode_slice(&mut data).is_err(), "Unknown version should return error"); } #[test] fn test_connect() { - let mut data = bm(&[ + let mut data: &[u8] = &[ 0b00010000, 39, 0x00, 0x04, 'M' as u8, 'Q' as u8, 'T' as u8, 'T' as u8, 0x04, 0b11001110, // +username, +password, -will retain, will qos=1, +last_will, +clean_session 0x00, 0x0a, // 10 sec @@ -52,29 +181,32 @@ fn test_connect() { 'e' as u8, // will msg = 'offline' 0x00, 0x04, 'r' as u8, 'u' as u8, 's' as u8, 't' as u8, // username = 'rust' 0x00, 0x02, 'm' as u8, 'q' as u8, // password = 'mq' - ]); + ]; let pkt = Connect { protocol: Protocol::MQTT311, keep_alive: 10, - client_id: "test".into(), + client_id: "test", clean_session: true, last_will: Some(LastWill { - topic: "/a".into(), - message: "offline".into(), + topic: "/a", + message: b"offline", qos: QoS::AtLeastOnce, retain: false, }), - username: Some("rust".into()), - password: Some("mq".into()), + username: Some("rust"), + password: Some(b"mq"), }; - assert_eq!(Ok(Some(pkt.into())), decode(&mut data)); - assert_eq!(data.len(), 0); + + let packet_buf = &mut [0u8; 64]; + assert_eq!(clone_packet(&mut data, &mut packet_buf[..]).unwrap(), 41); + assert_eq!(Ok(Some(pkt.into())), decode_slice(packet_buf)); + // assert_eq!(data.len(), 0); } #[test] fn test_connack() { - let mut data = bm(&[0b00100000, 2, 0b00000000, 0b00000001]); - let d = decoder::decode(&mut data).unwrap(); + let mut data: &[u8] = &[0b00100000, 2, 0b00000000, 0b00000001]; + let d = decode_slice(&mut data).unwrap(); match d { Some(Packet::Connack(c)) => { let o = Connack { @@ -90,60 +222,109 @@ fn test_connack() { #[test] fn test_ping_req() { - let mut data = bm(&[0b11000000, 0b00000000]); - assert_eq!(Ok(Some(Packet::Pingreq)), decode(&mut data)); + let mut data: &[u8] = &[0b11000000, 0b00000000]; + assert_eq!(Ok(Some(Packet::Pingreq)), decode_slice(&mut data)); } #[test] fn test_ping_resp() { - let mut data = bm(&[0b11010000, 0b00000000]); - assert_eq!(Ok(Some(Packet::Pingresp)), decode(&mut data)); + let mut data: &[u8] = &[0b11010000, 0b00000000]; + assert_eq!(Ok(Some(Packet::Pingresp)), decode_slice(&mut data)); } #[test] fn test_disconnect() { - let mut data = bm(&[0b11100000, 0b00000000]); - assert_eq!(Ok(Some(Packet::Disconnect)), decode(&mut data)); + let mut data: &[u8] = &[0b11100000, 0b00000000]; + assert_eq!(Ok(Some(Packet::Disconnect)), decode_slice(&mut data)); +} + +#[test] +#[ignore] +fn test_offset_start() { + let mut data: &[u8] = &[ + 1, 2, 3, 0b00110000, 10, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 'h' as u8, 'e' as u8, + 'l' as u8, 'l' as u8, 'o' as u8, // + 0b00111000, 10, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 'h' as u8, 'e' as u8, + 'l' as u8, 'l' as u8, 'o' as u8, // + 0b00111101, 12, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 0, 10, 'h' as u8, 'e' as u8, + 'l' as u8, 'l' as u8, 'o' as u8, + ]; + + let packet_buf = &mut [0u8; 64]; + assert_eq!(clone_packet(&mut data, &mut packet_buf[..]).unwrap(), 12); + assert_eq!(data.len(), 29); + + match decode_slice(packet_buf) { + Ok(Some(Packet::Publish(p))) => { + assert_eq!(p.dup, false); + assert_eq!(p.retain, false); + assert_eq!(p.qospid, QosPid::AtMostOnce); + assert_eq!(p.topic_name, "a/b"); + assert_eq!(core::str::from_utf8(p.payload).unwrap(), "hello"); + } + other => panic!("Failed decode: {:?}", other), + } } #[test] +#[ignore] fn test_publish() { - let mut data = bm(&[ + let mut data: &[u8] = &[ 0b00110000, 10, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, // 0b00111000, 10, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, // 0b00111101, 12, 0x00, 0x03, 'a' as u8, '/' as u8, 'b' as u8, 0, 10, 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8, - ]); + ]; - match decode(&mut data) { + let mut offset = 0; + assert_eq!( + decoder::read_header(&data, &mut offset).unwrap(), + Some((decoder::Header::new(0b00110000).unwrap(), 10)) + ); + assert_eq!(data.len(), 38); + + let packet_buf = &mut [0u8; 64]; + assert_eq!(clone_packet(&mut data, &mut packet_buf[..]).unwrap(), 12); + // assert_eq!(data.len(), 26); + + match decode_slice(packet_buf) { Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, false); assert_eq!(p.retain, false); assert_eq!(p.qospid, QosPid::AtMostOnce); assert_eq!(p.topic_name, "a/b"); - assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); + assert_eq!(core::str::from_utf8(p.payload).unwrap(), "hello"); } other => panic!("Failed decode: {:?}", other), } - match decode(&mut data) { + + let packet_buf2 = &mut [0u8; 64]; + assert_eq!(clone_packet(&mut data, &mut packet_buf2[..]).unwrap(), 12); + // assert_eq!(data.len(), 14); + match decode_slice(packet_buf2) { Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, true); assert_eq!(p.retain, false); assert_eq!(p.qospid, QosPid::AtMostOnce); assert_eq!(p.topic_name, "a/b"); - assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); + assert_eq!(core::str::from_utf8(p.payload).unwrap(), "hello"); } other => panic!("Failed decode: {:?}", other), } - match decode(&mut data) { + + let packet_buf3 = &mut [0u8; 64]; + assert_eq!(clone_packet(&mut data, &mut packet_buf3[..]).unwrap(), 14); + // assert_eq!(data.len(), 0); + + match decode_slice(packet_buf3) { Ok(Some(Packet::Publish(p))) => { assert_eq!(p.dup, true); assert_eq!(p.retain, true); assert_eq!(p.qospid, QosPid::from_u8u16(2, 10)); assert_eq!(p.topic_name, "a/b"); - assert_eq!(String::from_utf8(p.payload).unwrap(), "hello"); + assert_eq!(core::str::from_utf8(p.payload).unwrap(), "hello"); } other => panic!("Failed decode: {:?}", other), } @@ -151,8 +332,8 @@ fn test_publish() { #[test] fn test_pub_ack() { - let mut data = bm(&[0b01000000, 0b00000010, 0, 10]); - match decode(&mut data) { + let mut data: &[u8] = &[0b01000000, 0b00000010, 0, 10]; + match decode_slice(&mut data) { Ok(Some(Packet::Puback(a))) => assert_eq!(a.get(), 10), other => panic!("Failed decode: {:?}", other), }; @@ -160,8 +341,8 @@ fn test_pub_ack() { #[test] fn test_pub_rec() { - let mut data = bm(&[0b01010000, 0b00000010, 0, 10]); - match decode(&mut data) { + let mut data: &[u8] = &[0b01010000, 0b00000010, 0, 10]; + match decode_slice(&mut data) { Ok(Some(Packet::Pubrec(a))) => assert_eq!(a.get(), 10), other => panic!("Failed decode: {:?}", other), }; @@ -169,8 +350,8 @@ fn test_pub_rec() { #[test] fn test_pub_rel() { - let mut data = bm(&[0b01100010, 0b00000010, 0, 10]); - match decode(&mut data) { + let mut data: &[u8] = &[0b01100010, 0b00000010, 0, 10]; + match decode_slice(&mut data) { Ok(Some(Packet::Pubrel(a))) => assert_eq!(a.get(), 10), other => panic!("Failed decode: {:?}", other), }; @@ -178,8 +359,8 @@ fn test_pub_rel() { #[test] fn test_pub_comp() { - let mut data = bm(&[0b01110000, 0b00000010, 0, 10]); - match decode(&mut data) { + let mut data: &[u8] = &[0b01110000, 0b00000010, 0, 10]; + match decode_slice(&mut data) { Ok(Some(Packet::Pubcomp(a))) => assert_eq!(a.get(), 10), other => panic!("Failed decode: {:?}", other), }; @@ -187,17 +368,17 @@ fn test_pub_comp() { #[test] fn test_subscribe() { - let mut data = bm(&[ + let mut data: &[u8] = &[ 0b10000010, 8, 0, 10, 0, 3, 'a' as u8, '/' as u8, 'b' as u8, 0, - ]); - match decode(&mut data) { + ]; + match decode_slice(&mut data) { Ok(Some(Packet::Subscribe(s))) => { assert_eq!(s.pid.get(), 10); let t = SubscribeTopic { - topic_path: "a/b".to_string(), + topic_path: LimitedString::from("a/b"), qos: QoS::AtMostOnce, }; - assert_eq!(s.topics[0], t); + assert_eq!(s.topics.get(0), Some(&t)); } other => panic!("Failed decode: {:?}", other), } @@ -205,13 +386,13 @@ fn test_subscribe() { #[test] fn test_suback() { - let mut data = bm(&[0b10010000, 3, 0, 10, 0b00000010]); - match decode(&mut data) { + let mut data: &[u8] = &[0b10010000, 3, 0, 10, 0b00000010]; + match decode_slice(&mut data) { Ok(Some(Packet::Suback(s))) => { assert_eq!(s.pid.get(), 10); assert_eq!( - s.return_codes[0], - SubscribeReturnCodes::Success(QoS::ExactlyOnce) + s.return_codes.get(0), + Some(&SubscribeReturnCodes::Success(QoS::ExactlyOnce)) ); } other => panic!("Failed decode: {:?}", other), @@ -220,11 +401,11 @@ fn test_suback() { #[test] fn test_unsubscribe() { - let mut data = bm(&[0b10100010, 5, 0, 10, 0, 1, 'a' as u8]); - match decode(&mut data) { + let mut data: &[u8] = &[0b10100010, 5, 0, 10, 0, 1, 'a' as u8]; + match decode_slice(&mut data) { Ok(Some(Packet::Unsubscribe(a))) => { assert_eq!(a.pid.get(), 10); - assert_eq!(a.topics[0], 'a'.to_string()); + assert_eq!(a.topics.get(0), Some(&LimitedString::from("a"))); } other => panic!("Failed decode: {:?}", other), } @@ -232,8 +413,8 @@ fn test_unsubscribe() { #[test] fn test_unsub_ack() { - let mut data = bm(&[0b10110000, 2, 0, 10]); - match decode(&mut data) { + let mut data: &[u8] = &[0b10110000, 2, 0, 10]; + match decode_slice(&mut data) { Ok(Some(Packet::Unsuback(p))) => { assert_eq!(p.get(), 10); } diff --git a/src/encoder.rs b/src/encoder.rs index 1fa7433..7716203 100644 --- a/src/encoder.rs +++ b/src/encoder.rs @@ -1,5 +1,4 @@ use crate::{Error, Packet}; -use bytes::BufMut; /// Encode a [Packet] enum into a [BufMut] buffer. /// @@ -11,97 +10,105 @@ use bytes::BufMut; /// dup: false, /// qospid: QosPid::AtMostOnce, /// retain: false, -/// topic_name: "test".into(), -/// payload: "hello".into(), +/// topic_name: "test", +/// payload: b"hello", /// }.into(); /// /// // Allocate buffer (should be appropriately-sized or able to grow as needed). -/// let mut buf = BytesMut::with_capacity(1024); +/// let mut buf = [0u8; 1024]; /// /// // Write bytes corresponding to `&Packet` into the `BytesMut`. -/// encode(&packet, &mut buf).expect("failed encoding"); -/// assert_eq!(&*buf, &[0b00110000, 11, +/// let len = encode_slice(&packet, &mut buf).expect("failed encoding"); +/// assert_eq!(&buf[..len], &[0b00110000, 11, /// 0, 4, 't' as u8, 'e' as u8, 's' as u8, 't' as u8, /// 'h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8]); /// ``` /// /// [Packet]: ../enum.Packet.html /// [BufMut]: https://docs.rs/bytes/0.5.3/bytes/trait.BufMut.html -pub fn encode(packet: &Packet, buf: &mut impl BufMut) -> Result { +// #[cfg(feature = "std")] +// pub fn encode_slice(packet: &Packet, buf: impl BufMut) -> Result { +// let mut offset = 0; +// encode_slice(packet, buf.bytes_mut(), &mut offset) +// } + +pub fn encode_slice(packet: &Packet, buf: &mut [u8]) -> Result { + let mut offset = 0; + match packet { - Packet::Connect(connect) => connect.to_buffer(buf), - Packet::Connack(connack) => connack.to_buffer(buf), - Packet::Publish(publish) => publish.to_buffer(buf), + Packet::Connect(connect) => connect.to_buffer(buf, &mut offset), + Packet::Connack(connack) => connack.to_buffer(buf, &mut offset), + Packet::Publish(publish) => publish.to_buffer(buf, &mut offset), Packet::Puback(pid) => { - check_remaining(buf, 4)?; + check_remaining(buf, &mut offset, 4)?; let header: u8 = 0b01000000; let length: u8 = 2; - buf.put_u8(header); - buf.put_u8(length); - pid.to_buffer(buf)?; + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; + pid.to_buffer(buf, &mut offset)?; Ok(4) } Packet::Pubrec(pid) => { - check_remaining(buf, 4)?; + check_remaining(buf, &mut offset, 4)?; let header: u8 = 0b01010000; let length: u8 = 2; - buf.put_u8(header); - buf.put_u8(length); - pid.to_buffer(buf)?; + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; + pid.to_buffer(buf, &mut offset)?; Ok(4) } Packet::Pubrel(pid) => { - check_remaining(buf, 4)?; + check_remaining(buf, &mut offset, 4)?; let header: u8 = 0b01100010; let length: u8 = 2; - buf.put_u8(header); - buf.put_u8(length); - pid.to_buffer(buf)?; + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; + pid.to_buffer(buf, &mut offset)?; Ok(4) } Packet::Pubcomp(pid) => { - check_remaining(buf, 4)?; + check_remaining(buf, &mut offset, 4)?; let header: u8 = 0b01110000; let length: u8 = 2; - buf.put_u8(header); - buf.put_u8(length); - pid.to_buffer(buf)?; + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; + pid.to_buffer(buf, &mut offset)?; Ok(4) } - Packet::Subscribe(subscribe) => subscribe.to_buffer(buf), - Packet::Suback(suback) => suback.to_buffer(buf), - Packet::Unsubscribe(unsub) => unsub.to_buffer(buf), + Packet::Subscribe(subscribe) => subscribe.to_buffer(buf, &mut offset), + Packet::Suback(suback) => suback.to_buffer(buf, &mut offset), + Packet::Unsubscribe(unsub) => unsub.to_buffer(buf, &mut offset), Packet::Unsuback(pid) => { - check_remaining(buf, 4)?; + check_remaining(buf, &mut offset, 4)?; let header: u8 = 0b10110000; let length: u8 = 2; - buf.put_u8(header); - buf.put_u8(length); - pid.to_buffer(buf)?; + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; + pid.to_buffer(buf, &mut offset)?; Ok(4) } Packet::Pingreq => { - check_remaining(buf, 2)?; + check_remaining(buf, &mut offset, 2)?; let header: u8 = 0b11000000; let length: u8 = 0; - buf.put_u8(header); - buf.put_u8(length); + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; Ok(2) } Packet::Pingresp => { - check_remaining(buf, 2)?; + check_remaining(buf, &mut offset, 2)?; let header: u8 = 0b11010000; let length: u8 = 0; - buf.put_u8(header); - buf.put_u8(length); + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; Ok(2) } Packet::Disconnect => { - check_remaining(buf, 2)?; + check_remaining(buf, &mut offset, 2)?; let header: u8 = 0b11100000; let length: u8 = 0; - buf.put_u8(header); - buf.put_u8(length); + write_u8(buf, &mut offset, header)?; + write_u8(buf, &mut offset, length)?; Ok(2) } } @@ -109,8 +116,8 @@ pub fn encode(packet: &Packet, buf: &mut impl BufMut) -> Result { /// Check wether buffer has `len` bytes of write capacity left. Use this to return a clean /// Result::Err instead of panicking. -pub(crate) fn check_remaining(buf: &impl BufMut, len: usize) -> Result<(), Error> { - if buf.remaining_mut() < len { +pub(crate) fn check_remaining(buf: &mut [u8], offset: &mut usize, len: usize) -> Result<(), Error> { + if buf[*offset..].len() < len { Err(Error::WriteZero) } else { Ok(()) @@ -118,24 +125,24 @@ pub(crate) fn check_remaining(buf: &impl BufMut, len: usize) -> Result<(), Error } /// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718023 -pub(crate) fn write_length(len: usize, buf: &mut impl BufMut) -> Result { +pub(crate) fn write_length(buf: &mut [u8], offset: &mut usize, len: usize) -> Result { let write_len = match len { 0..=127 => { - check_remaining(buf, len + 1)?; + check_remaining(buf, offset, len + 1)?; len + 1 - }, + } 128..=16383 => { - check_remaining(buf, len + 2)?; + check_remaining(buf, offset, len + 2)?; len + 2 - }, + } 16384..=2097151 => { - check_remaining(buf, len + 3)?; + check_remaining(buf, offset, len + 3)?; len + 3 - }, + } 2097152..=268435455 => { - check_remaining(buf, len + 4)?; + check_remaining(buf, offset, len + 4)?; len + 4 - }, + } _ => return Err(Error::InvalidLength), }; let mut done = false; @@ -146,18 +153,33 @@ pub(crate) fn write_length(len: usize, buf: &mut impl BufMut) -> Result 0 { byte = byte | 128; } - buf.put_u8(byte); + write_u8(buf, offset, byte)?; done = x <= 0; } Ok(write_len) } -pub(crate) fn write_bytes(bytes: &[u8], buf: &mut impl BufMut) -> Result<(), Error> { - buf.put_u16(bytes.len() as u16); - buf.put_slice(bytes); + +pub(crate) fn write_u8(buf: &mut [u8], offset: &mut usize, val: u8) -> Result<(), Error> { + buf[*offset] = val; + *offset += 1; + Ok(()) +} + +pub(crate) fn write_u16(buf: &mut [u8], offset: &mut usize, val: u16) -> Result<(), Error> { + write_u8(buf, offset, (val >> 8) as u8)?; + write_u8(buf, offset, (val & 0xFF) as u8) +} + +pub(crate) fn write_bytes(buf: &mut [u8], offset: &mut usize, bytes: &[u8]) -> Result<(), Error> { + write_u16(buf, offset, bytes.len() as u16)?; + + for &byte in bytes { + write_u8(buf, offset, byte)?; + } Ok(()) } -pub(crate) fn write_string(string: &str, buf: &mut impl BufMut) -> Result<(), Error> { - write_bytes(string.as_bytes(), buf) +pub(crate) fn write_string(buf: &mut [u8], offset: &mut usize, string: &str) -> Result<(), Error> { + write_bytes(buf, offset, string.as_bytes(), ) } diff --git a/src/encoder_test.rs b/src/encoder_test.rs index a2f295d..918a6d3 100644 --- a/src/encoder_test.rs +++ b/src/encoder_test.rs @@ -1,15 +1,32 @@ use crate::*; -use bytes::BytesMut; use core::convert::TryFrom; -use alloc::string::ToString; -use alloc::vec; - -macro_rules! assert_decode { - ($res:pat, $pkt:expr) => { - let mut buf = BytesMut::with_capacity(1024); - let written = encode($pkt, &mut buf).unwrap(); - assert_eq!(written, buf.len()); - match decode(&mut buf) { +use subscribe::{LimitedString, LimitedVec}; + +#[cfg(feature = "std")] +use bytes::BytesMut; + +// macro_rules! assert_decode { +// ($res:pat, $pkt:expr) => { +// let mut buf = BytesMut::with_capacity(1024); +// let written = encode($pkt, &mut buf).unwrap(); +// assert_eq!(written, buf.len()); +// match decode_slice(&mut buf) { +// Ok(Some($res)) => (), +// err => assert!( +// false, +// "Expected: Ok(Some({})) got: {:?}", +// stringify!($res), +// err +// ), +// } +// }; +// } +macro_rules! assert_decode_slice { + ($res:pat, $pkt:expr, $written_exp:expr) => { + let mut slice = [0u8; 512]; + let written = encode_slice($pkt, &mut slice).unwrap(); + assert_eq!(written, $written_exp); + match decode_slice(&slice[..written]) { Ok(Some($res)) => (), err => assert!( false, @@ -26,13 +43,39 @@ fn test_connect() { let packet = Connect { protocol: Protocol::new("MQTT", 4).unwrap(), keep_alive: 120, - client_id: "imvj".to_string(), + client_id: "imvj", clean_session: true, last_will: None, username: None, password: None, - }; - assert_decode!(Packet::Connect(_), &packet.into()); + } + .into(); + // assert_decode!(Packet::Connect(_), &packet); + assert_decode_slice!(Packet::Connect(_), &packet, 18); +} + +#[test] +fn test_write_zero() { + let packet = Connect { + protocol: Protocol::new("MQTT", 4).unwrap(), + keep_alive: 120, + client_id: "imvj", + clean_session: true, + last_will: None, + username: None, + password: None, + } + .into(); + + let mut slice = [0u8; 8]; + match encode_slice(&packet, &mut slice) { + Ok(_) => panic!("Expected Error::WriteZero, as input slice is too small"), + Err(e) => assert_eq!(e, Error::WriteZero), + } + + let mut buf = [0u8; 80]; + let written = encode_slice(&packet, &mut buf).unwrap(); + assert_eq!(written, 18); } #[test] @@ -40,8 +83,10 @@ fn test_connack() { let packet = Connack { session_present: true, code: ConnectReturnCode::Accepted, - }; - assert_decode!(Packet::Connack(_), &packet.into()); + } + .into(); + // assert_decode!(Packet::Connack(_), &packet); + assert_decode_slice!(Packet::Connack(_), &packet, 4); } #[test] @@ -50,85 +95,95 @@ fn test_publish() { dup: false, qospid: QosPid::from_u8u16(2, 10), retain: true, - topic_name: "asdf".to_string(), - payload: vec!['h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8], - }; - assert_decode!(Packet::Publish(_), &packet.into()); + topic_name: "asdf", + payload: &['h' as u8, 'e' as u8, 'l' as u8, 'l' as u8, 'o' as u8], + } + .into(); + // assert_decode!(Packet::Publish(_), &packet); + assert_decode_slice!(Packet::Publish(_), &packet, 15); } #[test] fn test_puback() { let packet = Packet::Puback(Pid::try_from(19).unwrap()); - assert_decode!(Packet::Puback(_), &packet); + // assert_decode!(Packet::Puback(_), &packet); + assert_decode_slice!(Packet::Puback(_), &packet, 4); } #[test] fn test_pubrec() { let packet = Packet::Pubrec(Pid::try_from(19).unwrap()); - assert_decode!(Packet::Pubrec(_), &packet); + // assert_decode!(Packet::Pubrec(_), &packet); + assert_decode_slice!(Packet::Pubrec(_), &packet, 4); } #[test] fn test_pubrel() { let packet = Packet::Pubrel(Pid::try_from(19).unwrap()); - assert_decode!(Packet::Pubrel(_), &packet); + // assert_decode!(Packet::Pubrel(_), &packet); + assert_decode_slice!(Packet::Pubrel(_), &packet, 4); } #[test] fn test_pubcomp() { let packet = Packet::Pubcomp(Pid::try_from(19).unwrap()); - assert_decode!(Packet::Pubcomp(_), &packet); + // assert_decode!(Packet::Pubcomp(_), &packet); + assert_decode_slice!(Packet::Pubcomp(_), &packet, 4); } #[test] fn test_subscribe() { let stopic = SubscribeTopic { - topic_path: "a/b".to_string(), + topic_path: LimitedString::from("a/b"), qos: QoS::ExactlyOnce, }; - let packet = Subscribe { - pid: Pid::try_from(345).unwrap(), - topics: vec![stopic], - }; - assert_decode!(Packet::Subscribe(_), &Packet::Subscribe(packet)); + let topics: LimitedVec = [stopic].iter().cloned().collect(); + let packet = Subscribe::new(Pid::try_from(345).unwrap(), topics).into(); + // assert_decode!(Packet::Subscribe(_), &packet); + assert_decode_slice!(Packet::Subscribe(_), &packet, 10); } #[test] fn test_suback() { - let return_code = SubscribeReturnCodes::Success(QoS::ExactlyOnce); - let packet = Suback { - pid: Pid::try_from(12321).unwrap(), - return_codes: vec![return_code], - }; - assert_decode!(Packet::Suback(_), &Packet::Suback(packet)); + let return_codes = [SubscribeReturnCodes::Success(QoS::ExactlyOnce)] + .iter() + .cloned() + .collect(); + let packet = Suback::new(Pid::try_from(12321).unwrap(), return_codes).into(); + // assert_decode!(Packet::Suback(_), &packet); + assert_decode_slice!(Packet::Suback(_), &packet, 5); } #[test] fn test_unsubscribe() { - let packet = Unsubscribe { - pid: Pid::try_from(12321).unwrap(), - topics: vec!["a/b".to_string()], - }; - assert_decode!(Packet::Unsubscribe(_), &Packet::Unsubscribe(packet)); + let topics: LimitedVec = [LimitedString::from("a/b")].iter().cloned().collect(); + + let packet = Unsubscribe::new(Pid::try_from(12321).unwrap(), topics).into(); + // assert_decode!(Packet::Unsubscribe(_), &packet); + assert_decode_slice!(Packet::Unsubscribe(_), &packet, 9); } #[test] fn test_unsuback() { let packet = Packet::Unsuback(Pid::try_from(19).unwrap()); - assert_decode!(Packet::Unsuback(_), &packet); + // assert_decode!(Packet::Unsuback(_), &packet); + assert_decode_slice!(Packet::Unsuback(_), &packet, 4); } #[test] fn test_ping_req() { - assert_decode!(Packet::Pingreq, &Packet::Pingreq); + // assert_decode!(Packet::Pingreq, &Packet::Pingreq); + assert_decode_slice!(Packet::Pingreq, &Packet::Pingreq, 2); } #[test] fn test_ping_resp() { - assert_decode!(Packet::Pingresp, &Packet::Pingresp); + // assert_decode!(Packet::Pingresp, &Packet::Pingresp); + assert_decode_slice!(Packet::Pingresp, &Packet::Pingresp, 2); } #[test] fn test_disconnect() { - assert_decode!(Packet::Disconnect, &Packet::Disconnect); + // assert_decode!(Packet::Disconnect, &Packet::Disconnect); + assert_decode_slice!(Packet::Disconnect, &Packet::Disconnect, 2); } diff --git a/src/lib.rs b/src/lib.rs index e596fb4..e01158e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,36 +12,36 @@ //! use bytes::BytesMut; //! //! // Allocate buffer. -//! let mut buf = BytesMut::with_capacity(1024); +//! let mut buf = [0u8; 1024]; //! //! // Encode an MQTT Connect packet. //! let pkt = Packet::Connect(Connect { protocol: Protocol::MQTT311, //! keep_alive: 30, -//! client_id: "doc_client".into(), +//! client_id: "doc_client", //! clean_session: true, //! last_will: None, //! username: None, //! password: None }); -//! assert!(encode(&pkt, &mut buf).is_ok()); -//! assert_eq!(&buf[14..], "doc_client".as_bytes()); +//! let len = encode_slice(&pkt, &mut buf).unwrap(); +//! assert_eq!(&buf[14..len], b"doc_client"); //! let mut encoded = buf.clone(); //! //! // Decode one packet. The buffer will advance to the next packet. -//! assert_eq!(Ok(Some(pkt)), decode(&mut buf)); +//! assert_eq!(Ok(Some(pkt)), decode_slice(&mut buf)); //! //! // Example decode failures. -//! let mut incomplete = encoded.split_to(10); -//! assert_eq!(Ok(None), decode(&mut incomplete)); +//! let mut incomplete = encoded.split_at(10).0; +//! assert_eq!(Ok(None), decode_slice(&mut incomplete)); //! let mut garbage = BytesMut::from(&[0u8,0,0,0] as &[u8]); -//! assert_eq!(Err(Error::InvalidHeader), decode(&mut garbage)); +//! assert_eq!(Err(Error::InvalidHeader), decode_slice(&mut garbage)); //! ``` //! //! [MQTT 3.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html //! [MQTT 5]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html //! [tokio]: https://tokio.rs/ //! [Packet]: enum.Packet.html -//! [encode()]: fn.encode.html -//! [decode()]: fn.decode.html +//! [encode_slice()]: fn.encode_slice.html +//! [decode_slice()]: fn.decode_slice.html //! [bytes::BytesMut]: https://docs.rs/bytes/0.5.3/bytes/struct.BytesMut.html #![cfg_attr(not(test), no_std)] @@ -49,8 +49,6 @@ #[cfg(feature = "std")] extern crate std; -extern crate alloc; - mod connect; mod decoder; mod encoder; @@ -59,8 +57,11 @@ mod publish; mod subscribe; mod utils; -#[cfg(test)] -mod codec_test; +// Proptest does not currently support borrowed data in strategies: +// https://github.com/AltSysrq/proptest/issues/9 +// +// #[cfg(test)] +// mod codec_test; #[cfg(test)] mod decoder_test; #[cfg(test)] @@ -68,8 +69,8 @@ mod encoder_test; pub use crate::{ connect::{Connack, Connect, ConnectReturnCode, LastWill, Protocol}, - decoder::decode, - encoder::encode, + decoder::{clone_packet, decode_slice}, + encoder::encode_slice, packet::{Packet, PacketType}, publish::Publish, subscribe::{Suback, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe}, diff --git a/src/packet.rs b/src/packet.rs index f7040ee..d4c7a58 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -2,7 +2,7 @@ use crate::*; /// Base enum for all MQTT packet types. /// -/// This is the main type you'll be interacting with, as an output of [`decode()`] and an input of +/// This is the main type you'll be interacting with, as an output of [`decode_slice()`] and an input of /// [`encode()`]. Most variants can be constructed directly without using methods. /// /// ``` @@ -15,23 +15,23 @@ use crate::*; /// let publish = Publish { dup: false, /// qospid: QosPid::AtMostOnce, /// retain: false, -/// topic_name: "to/pic".into(), -/// payload: "payload".into() }; +/// topic_name: "to/pic", +/// payload: b"payload" }; /// let pkt: Packet = publish.into(); /// // Identifyer-only packets /// let pkt = Packet::Puback(Pid::try_from(42).unwrap()); /// ``` /// /// [`encode()`]: fn.encode.html -/// [`decode()`]: fn.decode.html +/// [`decode_slice()`]: fn.decode_slice.html #[derive(Debug, Clone, PartialEq)] -pub enum Packet { +pub enum Packet<'a> { /// [MQTT 3.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028) - Connect(Connect), + Connect(Connect<'a>), /// [MQTT 3.2](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033) Connack(Connack), /// [MQTT 3.3](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037) - Publish(Publish), + Publish(Publish<'a>), /// [MQTT 3.4](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718043) Puback(Pid), /// [MQTT 3.5](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718048) @@ -55,7 +55,7 @@ pub enum Packet { /// [MQTT 3.14](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718090) Disconnect, } -impl Packet { +impl<'a> Packet<'a> { /// Return the packet type variant. /// /// This can be used for matching, categorising, debuging, etc. Most users will match directly @@ -79,10 +79,22 @@ impl Packet { } } } + +macro_rules! packet_from_borrowed { + ($($t:ident),+) => { + $( + impl<'a> From<$t<'a>> for Packet<'a> { + fn from(p: $t<'a>) -> Self { + Packet::$t(p) + } + } + )+ + } +} macro_rules! packet_from { ($($t:ident),+) => { $( - impl From<$t> for Packet { + impl<'a> From<$t> for Packet<'a> { fn from(p: $t) -> Self { Packet::$t(p) } @@ -90,7 +102,9 @@ macro_rules! packet_from { )+ } } -packet_from!(Connect, Connack, Publish, Subscribe, Suback, Unsubscribe); + +packet_from_borrowed!(Connect, Publish); +packet_from!(Suback, Connack, Subscribe, Unsubscribe); /// Packet type variant, without the associated data. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] diff --git a/src/publish.rs b/src/publish.rs index c7d651e..e48ffba 100644 --- a/src/publish.rs +++ b/src/publish.rs @@ -1,27 +1,35 @@ use crate::{decoder::*, encoder::*, *}; -use alloc::{string::String, vec::Vec}; -use bytes::{BufMut, BytesMut}; + +// use alloc::{string::String, vec::Vec}; +use heapless::{String, Vec, consts}; + /// Publish packet ([MQTT 3.3]). /// /// [MQTT 3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037 #[derive(Debug, Clone, PartialEq)] -pub struct Publish { +pub struct Publish<'a> { pub dup: bool, pub qospid: QosPid, pub retain: bool, - pub topic_name: String, - pub payload: Vec, + pub topic_name: &'a str, + pub payload: &'a [u8], } -impl Publish { - pub(crate) fn from_buffer(header: &Header, buf: &mut BytesMut) -> Result { - let topic_name = read_string(buf)?; +impl<'a> Publish<'a> { + pub(crate) fn from_buffer( + header: &Header, + remaining_len: usize, + buf: &'a [u8], + offset: &mut usize, + ) -> Result { + let payload_end = *offset + remaining_len; + let topic_name = read_str(buf, offset)?; let qospid = match header.qos { QoS::AtMostOnce => QosPid::AtMostOnce, - QoS::AtLeastOnce => QosPid::AtLeastOnce(Pid::from_buffer(buf)?), - QoS::ExactlyOnce => QosPid::ExactlyOnce(Pid::from_buffer(buf)?), + QoS::AtLeastOnce => QosPid::AtLeastOnce(Pid::from_buffer(buf, offset)?), + QoS::ExactlyOnce => QosPid::ExactlyOnce(Pid::from_buffer(buf, offset)?), }; Ok(Publish { @@ -29,10 +37,10 @@ impl Publish { qospid, retain: header.retain, topic_name, - payload: buf.to_vec(), + payload: &buf[*offset..payload_end], }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { // Header let mut header: u8 = match self.qospid { QosPid::AtMostOnce => 0b00110000, @@ -45,8 +53,8 @@ impl Publish { if self.retain { header |= 0b00000001 as u8; }; - check_remaining(buf, 1)?; - buf.put_u8(header); + check_remaining(buf, offset, 1)?; + write_u8(buf, offset, header)?; // Length: topic (2+len) + pid (0/2) + payload (len) let length = self.topic_name.len() @@ -56,20 +64,22 @@ impl Publish { } + self.payload.len(); - let write_len = write_length(length, buf)? + 1; + let write_len = write_length(buf, offset, length)? + 1; // Topic - write_string(self.topic_name.as_ref(), buf)?; + write_string(buf, offset, self.topic_name)?; // Pid match self.qospid { QosPid::AtMostOnce => (), - QosPid::AtLeastOnce(pid) => pid.to_buffer(buf)?, - QosPid::ExactlyOnce(pid) => pid.to_buffer(buf)?, + QosPid::AtLeastOnce(pid) => pid.to_buffer(buf, offset)?, + QosPid::ExactlyOnce(pid) => pid.to_buffer(buf, offset)?, } // Payload - buf.put_slice(self.payload.as_slice()); + for &byte in self.payload { + write_u8(buf, offset, byte)?; + } Ok(write_len) } diff --git a/src/subscribe.rs b/src/subscribe.rs index b8dfecf..5b124d1 100644 --- a/src/subscribe.rs +++ b/src/subscribe.rs @@ -1,8 +1,16 @@ use crate::{decoder::*, encoder::*, *}; -use bytes::{Buf, BufMut, BytesMut}; #[cfg(feature = "derive")] use serde::{Deserialize, Serialize}; -use alloc::{string::String, vec::Vec}; + +#[cfg(feature = "std")] +pub(crate) type LimitedVec = std::vec::Vec; +#[cfg(not(feature = "std"))] +pub(crate) type LimitedVec = heapless::Vec; + +#[cfg(feature = "std")] +pub(crate) type LimitedString = std::string::String; +#[cfg(not(feature = "std"))] +pub(crate) type LimitedString = heapless::String; /// Subscribe topic. /// @@ -12,10 +20,19 @@ use alloc::{string::String, vec::Vec}; #[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "derive", derive(Serialize, Deserialize))] pub struct SubscribeTopic { - pub topic_path: String, + pub topic_path: LimitedString, pub qos: QoS, } +impl SubscribeTopic { + pub(crate) fn from_buffer(buf: &[u8], offset: &mut usize) -> Result { + let topic_path = LimitedString::from(read_str(buf, offset)?); + let qos = QoS::from_u8(buf[*offset])?; + *offset += 1; + Ok(SubscribeTopic { topic_path, qos }) + } +} + /// Subscribe return value. /// /// [Suback] packets contain a `Vec` of those. @@ -26,7 +43,19 @@ pub enum SubscribeReturnCodes { Success(QoS), Failure, } + impl SubscribeReturnCodes { + pub(crate) fn from_buffer<'a>(buf: &'a [u8], offset: &mut usize) -> Result { + let code = buf[*offset]; + *offset += 1; + + if code == 0x80 { + Ok(SubscribeReturnCodes::Failure) + } else { + Ok(SubscribeReturnCodes::Success(QoS::from_u8(code)?)) + } + } + pub(crate) fn to_u8(&self) -> u8 { match *self { SubscribeReturnCodes::Failure => 0x80, @@ -41,7 +70,7 @@ impl SubscribeReturnCodes { #[derive(Debug, Clone, PartialEq)] pub struct Subscribe { pub pid: Pid, - pub topics: Vec, + pub topics: LimitedVec, } /// Subsack packet ([MQTT 3.9]). @@ -50,7 +79,7 @@ pub struct Subscribe { #[derive(Debug, Clone, PartialEq)] pub struct Suback { pub pid: Pid, - pub return_codes: Vec, + pub return_codes: LimitedVec, } /// Unsubscribe packet ([MQTT 3.10]). @@ -59,41 +88,52 @@ pub struct Suback { #[derive(Debug, Clone, PartialEq)] pub struct Unsubscribe { pub pid: Pid, - pub topics: Vec, + pub topics: LimitedVec, } impl Subscribe { - pub(crate) fn from_buffer(buf: &mut BytesMut) -> Result { - let pid = Pid::from_buffer(buf)?; - let mut topics: Vec = Vec::new(); - while buf.remaining() != 0 { - let topic_path = read_string(buf)?; - let qos = QoS::from_u8(buf.get_u8())?; - let topic = SubscribeTopic { topic_path, qos }; - topics.push(topic); + pub fn new(pid: Pid, topics: LimitedVec) -> Self { + Subscribe { pid, topics } + } + + pub(crate) fn from_buffer( + remaining_len: usize, + buf: &[u8], + offset: &mut usize, + ) -> Result { + let payload_end = *offset + remaining_len; + let pid = Pid::from_buffer(buf, offset)?; + + let mut topics = LimitedVec::new(); + while *offset < payload_end { + let _res = topics.push(SubscribeTopic::from_buffer(buf, offset)?); + + #[cfg(not(feature = "std"))] + _res.map_err(|_| Error::InvalidLength)?; } + Ok(Subscribe { pid, topics }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { let header: u8 = 0b10000010; - check_remaining(buf, 1)?; - buf.put_u8(header); + check_remaining(buf, offset, 1)?; + write_u8(buf, offset, header)?; // Length: pid(2) + topic.for_each(2+len + qos(1)) let mut length = 2; for topic in &self.topics { length += topic.topic_path.len() + 2 + 1; } - let write_len = write_length(length, buf)? + 1; + let write_len = write_length(buf, offset, length)? + 1; // Pid - self.pid.to_buffer(buf)?; + self.pid.to_buffer(buf, offset)?; // Topics for topic in &self.topics { - write_string(topic.topic_path.as_ref(), buf)?; - buf.put_u8(topic.qos.to_u8()); + write_string(buf, offset, topic.topic_path.as_str())?; + write_u8(buf, offset, topic.qos.to_u8())?; } Ok(write_len) @@ -101,59 +141,81 @@ impl Subscribe { } impl Unsubscribe { - pub(crate) fn from_buffer(buf: &mut BytesMut) -> Result { - let pid = Pid::from_buffer(buf)?; - let mut topics: Vec = Vec::new(); - while buf.remaining() != 0 { - let topic_path = read_string(buf)?; - topics.push(topic_path); + pub fn new(pid: Pid, topics: LimitedVec) -> Self { + Unsubscribe { pid, topics } + } + + pub(crate) fn from_buffer( + remaining_len: usize, + buf: &[u8], + offset: &mut usize, + ) -> Result { + let payload_end = *offset + remaining_len; + let pid = Pid::from_buffer(buf, offset)?; + + let mut topics = LimitedVec::new(); + while *offset < payload_end { + let _res = topics.push(LimitedString::from(read_str(buf, offset)?)); + + #[cfg(not(feature = "std"))] + _res.map_err(|_| Error::InvalidLength)?; } + Ok(Unsubscribe { pid, topics }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { let header: u8 = 0b10100010; let mut length = 2; for topic in &self.topics { length += 2 + topic.len(); } - check_remaining(buf, 1)?; - buf.put_u8(header); + check_remaining(buf, offset, 1)?; + write_u8(buf, offset, header)?; - let write_len = write_length(length, buf)? + 1; - self.pid.to_buffer(buf)?; + let write_len = write_length(buf, offset, length)? + 1; + self.pid.to_buffer(buf, offset)?; for topic in &self.topics { - write_string(topic.as_ref(), buf)?; + write_string(buf, offset, topic)?; } Ok(write_len) } } impl Suback { - pub(crate) fn from_buffer(buf: &mut impl Buf) -> Result { - let pid = Pid::from_buffer(buf)?; - let mut return_codes: Vec = Vec::new(); - while buf.remaining() != 0 { - let code = buf.get_u8(); - let r = if code == 0x80 { - SubscribeReturnCodes::Failure - } else { - SubscribeReturnCodes::Success(QoS::from_u8(code)?) - }; - return_codes.push(r); + pub fn new(pid: Pid, return_codes: LimitedVec) -> Self { + Suback { pid, return_codes } + } + + pub(crate) fn from_buffer( + remaining_len: usize, + buf: &[u8], + offset: &mut usize, + ) -> Result { + let payload_end = *offset + remaining_len; + let pid = Pid::from_buffer(buf, offset)?; + + let mut return_codes = LimitedVec::new(); + while *offset < payload_end { + let _res = return_codes.push(SubscribeReturnCodes::from_buffer(buf, offset)?); + + #[cfg(not(feature = "std"))] + _res.map_err(|_| Error::InvalidLength)?; } - Ok(Suback { return_codes, pid }) + + Ok(Suback { pid, return_codes }) } - pub(crate) fn to_buffer(&self, buf: &mut impl BufMut) -> Result { + + pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { let header: u8 = 0b10010000; let length = 2 + self.return_codes.len(); - check_remaining(buf, 1)?; - buf.put_u8(header); + check_remaining(buf, offset, 1)?; + write_u8(buf, offset, header)?; - let write_len = write_length(length, buf)? + 1; - self.pid.to_buffer(buf)?; + let write_len = write_length(buf, offset, length)? + 1; + self.pid.to_buffer(buf, offset)?; for rc in &self.return_codes { - buf.put_u8(rc.to_u8()); + write_u8(buf, offset, rc.to_u8())?; } Ok(write_len) } diff --git a/src/utils.rs b/src/utils.rs index 4a02ad1..b852959 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,15 +1,13 @@ -use alloc::string::String; -use bytes::{Buf, BufMut}; use core::{convert::TryFrom, fmt, num::NonZeroU16}; +use crate::encoder::write_u16; #[cfg(feature = "derive")] use serde::{Deserialize, Serialize}; -#[cfg(feature = "std")] -use alloc::format; #[cfg(feature = "std")] use std::{ error::Error as ErrorTrait, + format, io::{Error as IoError, ErrorKind}, }; @@ -30,7 +28,10 @@ pub enum Error { /// Tried to decode a ConnectReturnCode > 5. InvalidConnectReturnCode(u8), /// Tried to decode an unknown protocol. - InvalidProtocol(String, u8), + #[cfg(feature = "std")] + InvalidProtocol(std::string::String, u8), + #[cfg(not(feature = "std"))] + InvalidProtocol(heapless::String, u8), /// Tried to decode an invalid fixed header (packet type, flags, or remaining_length). InvalidHeader, /// Trying to encode/decode an invalid length. @@ -45,7 +46,7 @@ pub enum Error { /// Note: Only available when std is available. /// You'll hopefully never see this. #[cfg(feature = "std")] - IoError(ErrorKind, String), + IoError(ErrorKind, std::string::String), } #[cfg(feature = "std")] @@ -113,24 +114,32 @@ impl Pid { pub fn new() -> Self { Pid(NonZeroU16::new(1).unwrap()) } + /// Get the `Pid` as a raw `u16`. pub fn get(self) -> u16 { self.0.get() } - pub(crate) fn from_buffer(buf: &mut impl Buf) -> Result { - Self::try_from(buf.get_u16()) + + pub(crate) fn from_buffer<'a>(buf: &'a [u8], offset: &mut usize) -> Result { + let pid = ((buf[*offset] as u16) << 8) | buf[*offset + 1] as u16; + *offset += 2; + Self::try_from(pid) } - pub(crate) fn to_buffer(self, buf: &mut impl BufMut) -> Result<(), Error> { - Ok(buf.put_u16(self.get())) + + pub(crate) fn to_buffer(self, buf: &mut [u8], offset: &mut usize) -> Result<(), Error> { + write_u16(buf, offset, self.get()) } } + impl Default for Pid { fn default() -> Pid { Pid::new() } } + impl core::ops::Add for Pid { type Output = Pid; + /// Adding a `u16` to a `Pid` will wrap around and avoid 0. fn add(self, u: u16) -> Pid { let n = match self.get().overflowing_add(u) { @@ -140,8 +149,10 @@ impl core::ops::Add for Pid { Pid(NonZeroU16::new(n).unwrap()) } } + impl core::ops::Sub for Pid { type Output = Pid; + /// Adding a `u16` to a `Pid` will wrap around and avoid 0. fn sub(self, u: u16) -> Pid { let n = match self.get().overflowing_sub(u) { @@ -152,14 +163,17 @@ impl core::ops::Sub for Pid { Pid(NonZeroU16::new(n).unwrap()) } } + impl From for u16 { /// Convert `Pid` to `u16`. fn from(p: Pid) -> Self { p.0.get() } } + impl TryFrom for Pid { type Error = Error; + /// Convert `u16` to `Pid`. Will fail for value 0. fn try_from(u: u16) -> Result { match NonZeroU16::new(u) { @@ -182,6 +196,7 @@ pub enum QoS { /// `QoS 2`. Two acks needed. ExactlyOnce, } + impl QoS { pub(crate) fn to_u8(&self) -> u8 { match *self { @@ -190,6 +205,7 @@ impl QoS { QoS::ExactlyOnce => 2, } } + pub(crate) fn from_u8(byte: u8) -> Result { match byte { 0 => Ok(QoS::AtMostOnce), @@ -214,6 +230,7 @@ pub enum QosPid { AtLeastOnce(Pid), ExactlyOnce(Pid), } + impl QosPid { #[cfg(test)] pub(crate) fn from_u8u16(qos: u8, pid: u16) -> Self { @@ -224,6 +241,7 @@ impl QosPid { _ => panic!("Qos > 2"), } } + /// Extract the [`Pid`] from a `QosPid`, if any. /// /// [`Pid`]: struct.Pid.html @@ -234,6 +252,7 @@ impl QosPid { QosPid::ExactlyOnce(p) => Some(p), } } + /// Extract the [`QoS`] from a `QosPid`. /// /// [`QoS`]: enum.QoS.html @@ -249,9 +268,8 @@ impl QosPid { #[cfg(test)] mod test { use crate::Pid; - use alloc::vec; - use alloc::vec::Vec; use core::convert::TryFrom; + use std::vec; #[test] fn pid_add_sub() {