Skip to content

Commit

Permalink
chore: remove unnecessary MQTT trait impls
Browse files Browse the repository at this point in the history
  • Loading branch information
GunnarMorrigan committed Dec 3, 2024
1 parent b51f3b4 commit fa9ad2e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 112 deletions.
6 changes: 4 additions & 2 deletions mqrstt/src/packets/connect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ impl Default for Connect {

impl PacketRead for Connect {
fn read(_: u8, _: usize, mut buf: Bytes) -> Result<Self, DeserializeError> {
if String::read(&mut buf)? != "MQTT" {
return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_string()));
let expected_protocol = [b'M', b'Q', b'T', b'T'];
let received_protocol = Vec::<u8>::read(&mut buf)?;
if &received_protocol != &expected_protocol {
return Err(DeserializeError::MalformedPacketWithInfo("Protocol not MQTT".to_owned()));
}

let protocol_version = ProtocolVersion::read(&mut buf)?;
Expand Down
164 changes: 54 additions & 110 deletions mqrstt/src/packets/mqtt_trait/primitive_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use super::MqttAsyncWrite;
impl MqttRead for Box<str> {
#[inline]
fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
let content = Bytes::read(buf)?;
let content = Vec::<u8>::read(buf)?;

match String::from_utf8(content.to_vec()) {
Ok(s) => Ok(s.into()),
match String::from_utf8(content) {
Ok(s) => Ok(s.into_boxed_str()),
Err(e) => Err(DeserializeError::Utf8Error(e)),
}
}
Expand Down Expand Up @@ -86,117 +86,61 @@ impl WireLength for &str {
}
}

impl MqttRead for String {
#[inline]
fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
let content = Bytes::read(buf)?;
// impl MqttRead for Bytes {
// #[inline]
// fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
// if buf.len() < 2 {
// return Err(DeserializeError::InsufficientData(std::any::type_name::<Bytes>(), buf.len(), 2));
// }
// let len = buf.get_u16() as usize;

match String::from_utf8(content.to_vec()) {
Ok(s) => Ok(s),
Err(e) => Err(DeserializeError::Utf8Error(e)),
}
}
}
// if len > buf.len() {
// return Err(DeserializeError::InsufficientData(std::any::type_name::<Bytes>(), buf.len(), len));
// }

impl<T> MqttAsyncRead<T> for String
where
T: tokio::io::AsyncReadExt + std::marker::Unpin,
{
async fn async_read(buf: &mut T) -> Result<(Self, usize), ReadError> {
let (content, read_bytes) = Bytes::async_read(buf).await?;
match String::from_utf8(content.to_vec()) {
Ok(s) => Ok((s, read_bytes)),
Err(e) => Err(ReadError::DeserializeError(DeserializeError::Utf8Error(e))),
}
}
}
// Ok(buf.split_to(len))
// }
// }
// impl<S> MqttAsyncRead<S> for Bytes
// where
// S: tokio::io::AsyncReadExt + std::marker::Unpin,
// {
// async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> {
// let size = stream.read_u16().await? as usize;
// // let mut data = BytesMut::with_capacity(size);
// let mut data = Vec::with_capacity(size);
// let read_bytes = stream.read_exact(&mut data).await?;
// assert_eq!(size, read_bytes);
// Ok((data.into(), 2 + size))
// }
// }
// impl MqttWrite for Bytes {
// #[inline]
// fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
// buf.put_u16(self.len() as u16);
// buf.extend(self);

impl MqttWrite for String {
#[inline]
fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
if self.len() > 65535 {
return Err(SerializeError::StringTooLong(self.len()));
}
// Ok(())
// }
// }
// impl<S> MqttAsyncWrite<S> for Bytes
// where
// S: tokio::io::AsyncWrite + Unpin,
// {
// async fn async_write(&self, stream: &mut S) -> Result<usize, crate::packets::error::WriteError> {
// let size = (self.len() as u16).to_be_bytes();
// stream.write_all(&size).await?;
// stream.write_all(self.as_ref()).await?;
// Ok(2 + self.len())
// }
// }

buf.put_u16(self.len() as u16);
buf.extend(self.as_bytes());
Ok(())
}
}
impl<S> MqttAsyncWrite<S> for String
where
S: tokio::io::AsyncWrite + Unpin,
{
async fn async_write(&self, stream: &mut S) -> Result<usize, crate::packets::error::WriteError> {
let size = (self.len() as u16).to_be_bytes();
stream.write_all(&size).await?;
stream.write_all(self.as_bytes()).await?;
Ok(2 + self.len())
}
}

impl WireLength for String {
#[inline(always)]
fn wire_len(&self) -> usize {
self.len() + 2
}
}

impl MqttRead for Bytes {
#[inline]
fn read(buf: &mut Bytes) -> Result<Self, DeserializeError> {
if buf.len() < 2 {
return Err(DeserializeError::InsufficientData(std::any::type_name::<Bytes>(), buf.len(), 2));
}
let len = buf.get_u16() as usize;

if len > buf.len() {
return Err(DeserializeError::InsufficientData(std::any::type_name::<Bytes>(), buf.len(), len));
}

Ok(buf.split_to(len))
}
}
impl<S> MqttAsyncRead<S> for Bytes
where
S: tokio::io::AsyncReadExt + std::marker::Unpin,
{
async fn async_read(stream: &mut S) -> Result<(Self, usize), ReadError> {
let size = stream.read_u16().await? as usize;
// let mut data = BytesMut::with_capacity(size);
let mut data = Vec::with_capacity(size);
let read_bytes = stream.read_exact(&mut data).await?;
assert_eq!(size, read_bytes);
Ok((data.into(), 2 + size))
}
}
impl MqttWrite for Bytes {
#[inline]
fn write(&self, buf: &mut BytesMut) -> Result<(), SerializeError> {
buf.put_u16(self.len() as u16);
buf.extend(self);

Ok(())
}
}
impl<S> MqttAsyncWrite<S> for Bytes
where
S: tokio::io::AsyncWrite + Unpin,
{
async fn async_write(&self, stream: &mut S) -> Result<usize, crate::packets::error::WriteError> {
let size = (self.len() as u16).to_be_bytes();
stream.write_all(&size).await?;
stream.write_all(self.as_ref()).await?;
Ok(2 + self.len())
}
}

impl WireLength for Bytes {
#[inline(always)]
fn wire_len(&self) -> usize {
self.len() + 2
}
}
// impl WireLength for Bytes {
// #[inline(always)]
// fn wire_len(&self) -> usize {
// self.len() + 2
// }
// }

impl MqttRead for Vec<u8> {
#[inline]
Expand Down

0 comments on commit fa9ad2e

Please sign in to comment.