From 858e6c1a5716c3d50f21efbb6bbced2c388f97e9 Mon Sep 17 00:00:00 2001 From: Louis Thiery Date: Wed, 16 Aug 2023 14:32:43 -0700 Subject: [PATCH] Guard against indexing beyond a slice length (#83) * guard against indexing into short packets. improve some error msgs. * define PREFIX_LEN * simplify the const sizes; use the const instead of a magic number * fix comment on terminate fn --- src/packet/error.rs | 10 ++++--- src/packet/parser.rs | 70 ++++++++++++++++++++++++++------------------ src/packet/tx_ack.rs | 2 -- src/tests.rs | 1 - 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/src/packet/error.rs b/src/packet/error.rs index 8cb5fdf..885bdf5 100644 --- a/src/packet/error.rs +++ b/src/packet/error.rs @@ -13,10 +13,12 @@ pub enum Error { #[derive(Error, Debug)] pub enum ParseError { - #[error("invalid GWMP version")] - InvalidProtocolVersion, - #[error("invalid GWMP frame identifier")] - InvalidIdentifier, + #[error("invalid packet length: {0}. Requires at least {1} bytes")] + InvalidPacketLength(usize, usize), + #[error("invalid GWMP version: {0}")] + InvalidProtocolVersion(u8), + #[error("invalid GWMP frame identifier: {0}")] + InvalidIdentifier(u8), #[error("utf8 error")] Utf8(#[from] std::str::Utf8Error), #[error("invalid Json string for {identifier} frame: {json_str}. JsonError: {json_error}")] diff --git a/src/packet/parser.rs b/src/packet/parser.rs index 3581b82..b3f2e0a 100644 --- a/src/packet/parser.rs +++ b/src/packet/parser.rs @@ -1,33 +1,37 @@ use super::*; use crate::tx_ack::Data; -use std::convert::TryFrom; +use std::{convert::TryFrom, result::Result}; const PROTOCOL_VERSION_INDEX: usize = 0; const IDENTIFIER_INDEX: usize = 3; -const PACKET_PAYLOAD_START: usize = 8; +const PREFIX_LEN: usize = IDENTIFIER_INDEX + 1; +const GATEWAY_MAC_LEN: usize = 8; fn random_token(buffer: &[u8]) -> u16 { (buffer[1] as u16) << 8 | buffer[2] as u16 } -pub fn gateway_mac(buffer: &[u8]) -> MacAddress { - MacAddress::new( - buffer[0], buffer[1], buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], - ) -} - -pub trait Parser { - fn parse(buffer: &[u8]) -> std::result::Result; +pub fn gateway_mac(buffer: &[u8]) -> Result { + if buffer.len() < GATEWAY_MAC_LEN { + Err(ParseError::InvalidPacketLength( + buffer.len(), + GATEWAY_MAC_LEN, + )) + } else { + Ok(MacAddress::new( + buffer[0], buffer[1], buffer[2], buffer[3], buffer[4], buffer[5], buffer[6], buffer[7], + )) + } } impl Packet { - pub fn parse_uplink(buffer: &[u8]) -> std::result::Result { + pub fn parse_uplink(buffer: &[u8]) -> Result { match Self::parse(buffer)? { Packet::Up(up) => Ok(up), Packet::Down(down) => Err(ParseError::UnexpectedDownlink(down)), } } - pub fn parse_downlink(buffer: &[u8]) -> std::result::Result { + pub fn parse_downlink(buffer: &[u8]) -> Result { match Self::parse(buffer)? { Packet::Down(down) => Ok(down), Packet::Up(up) => Err(ParseError::UnexpectedUplink(Box::new(up))), @@ -35,21 +39,29 @@ impl Packet { } } -impl Parser for Packet { - fn parse(buffer: &[u8]) -> std::result::Result { - if buffer[PROTOCOL_VERSION_INDEX] != PROTOCOL_VERSION { - return Err(ParseError::InvalidProtocolVersion); +impl Packet { + pub fn parse(buffer: &[u8]) -> Result { + if buffer.len() < PREFIX_LEN { + return Err(ParseError::InvalidPacketLength(buffer.len(), PREFIX_LEN)); + } + + let protocol_version = buffer[PROTOCOL_VERSION_INDEX]; + if protocol_version != PROTOCOL_VERSION { + return Err(ParseError::InvalidProtocolVersion(protocol_version)); }; - match Identifier::try_from(buffer[IDENTIFIER_INDEX]) { - Err(_) => Err(ParseError::InvalidIdentifier), + let frame_identifier = buffer[IDENTIFIER_INDEX]; + match Identifier::try_from(frame_identifier) { + Err(_) => Err(ParseError::InvalidIdentifier(frame_identifier)), Ok(id) => { + // the token is before the identifier which we've already done a length check for let random_token = random_token(buffer); - let buffer = &buffer[4..]; + let buffer = &buffer[PREFIX_LEN..]; + Ok(match id { // up packets Identifier::PullData => { - let gateway_mac = gateway_mac(&buffer[..PACKET_PAYLOAD_START]); + let gateway_mac = gateway_mac(buffer)?; pull_data::Packet { random_token, gateway_mac, @@ -57,9 +69,9 @@ impl Parser for Packet { .into() } Identifier::PushData => { - let gateway_mac = gateway_mac(&buffer[..PACKET_PAYLOAD_START]); + let gateway_mac = gateway_mac(buffer)?; let json_str = - std::str::from_utf8(&buffer[PACKET_PAYLOAD_START..terminate(buffer)])?; + std::str::from_utf8(&buffer[GATEWAY_MAC_LEN..terminate(buffer)])?; let data = serde_json::from_str(json_str).map_err(|json_error| { ParseError::InvalidJson { identifier: id, @@ -75,16 +87,14 @@ impl Parser for Packet { .into() } Identifier::TxAck => { - let gateway_mac = gateway_mac(&buffer[..PACKET_PAYLOAD_START]); - let data = if buffer.len() > PACKET_PAYLOAD_START { + let gateway_mac = gateway_mac(buffer)?; + let data = if buffer.len() > GATEWAY_MAC_LEN { // guard against some packet forwarders that put a 0 byte as the last byte - if buffer.len() == PACKET_PAYLOAD_START + 1 - && buffer[PACKET_PAYLOAD_START] == 0 - { + if buffer.len() == GATEWAY_MAC_LEN + 1 && buffer[GATEWAY_MAC_LEN] == 0 { Data::default() } else { let json_str = std::str::from_utf8( - &buffer[PACKET_PAYLOAD_START..terminate(buffer)], + &buffer[GATEWAY_MAC_LEN..terminate(buffer)], )?; serde_json::from_str(json_str).map_err(|json_error| { ParseError::InvalidJson { @@ -126,7 +136,9 @@ impl Parser for Packet { // deals with null byte terminated json fn terminate(buf: &[u8]) -> usize { - if buf[buf.len() - 1] == 0 { + if buf.is_empty() { + 0 + } else if buf[buf.len() - 1] == 0 { buf.len() - 1 } else { buf.len() diff --git a/src/packet/tx_ack.rs b/src/packet/tx_ack.rs index 69aedb2..13a1527 100644 --- a/src/packet/tx_ack.rs +++ b/src/packet/tx_ack.rs @@ -347,7 +347,6 @@ fn tx_nack_tx_power_sx1302_ser() { #[test] fn null_terminate() { - use crate::packet::parser::Parser; let bytes = hex::decode("02904905aa555a00000000007b227478706b5f61636b223a7b227761726e223a2254585f504f574552222c2276616c7565223a32372c22746d7374223a333937353336363839317d7d00").unwrap(); println!("{bytes:?}"); let frame = crate::packet::Packet::parse(&bytes).unwrap(); @@ -356,7 +355,6 @@ fn null_terminate() { #[test] fn dont_null_terminate() { - use crate::packet::parser::Parser; let bytes = hex::decode("02904905aa555a00000000007b227478706b5f61636b223a7b227761726e223a2254585f504f574552222c2276616c7565223a32372c22746d7374223a333937353336363839317d7d").unwrap(); println!("{bytes:?}"); let frame = crate::packet::Packet::parse(&bytes).unwrap(); diff --git a/src/tests.rs b/src/tests.rs index 8245959..6b900a1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,4 +1,3 @@ -use super::packet::parser::Parser; use super::*; #[test] fn test_pull_data() {