Skip to content

Commit

Permalink
Add safety comments for all unsafe usage
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Reidel <adrian@travitia.xyz>
  • Loading branch information
Gelbpunkt committed Jul 19, 2023
1 parent 05cd3d1 commit 0ba020d
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 6 deletions.
32 changes: 26 additions & 6 deletions src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,8 @@ impl Message {
if data.is_empty() {
Ok(Self { opcode, data })
} else {
// SAFETY: The Decoder ensures that close frames consist of at least two bytes
// A conversion from two u8s to a u16 cannot fail.
let close_code_value = u16::from_be_bytes(unsafe {
data.get_unchecked(0..2).try_into().unwrap_unchecked()
});
Expand All @@ -383,6 +385,8 @@ impl Message {

// Verify that the reason is allowed
if data.len() > 2 {
// SAFETY: The Decoder ensures that close frames consist of at least two
// bytes
utf8::parse_str(unsafe { data.get_unchecked(2..) })?;
}

Expand Down Expand Up @@ -507,6 +511,8 @@ impl Message {
/// binary or binary and invalid UTF-8.
pub fn as_text(&self) -> Result<&str, ProtocolError> {
match self.opcode {
// SAFETY: UTF-8 is validated by the Decoder and/or when the message is assembled from
// frames in the case of text messages.
OpCode::Text => Ok(unsafe { std::str::from_utf8_unchecked(&self.data) }),
OpCode::Binary => Ok(utf8::parse_str(&self.data)?),
_ => Err(ProtocolError::MessageHasWrongOpcode),
Expand All @@ -522,6 +528,7 @@ impl Message {
pub fn as_close(&self) -> Result<(Option<CloseCode>, Option<&str>), ProtocolError> {
if self.opcode == OpCode::Close {
let close_code = if self.data.len() >= 2 {
// SAFETY: self.data.len() is greater or equal to 2
let close_code_value = u16::from_be_bytes(unsafe {
self.data.get_unchecked(0..2).try_into().unwrap_unchecked()
});
Expand All @@ -531,6 +538,7 @@ impl Message {
};

let reason = if self.data.len() > 2 {
// SAFETY: self.data.len() is greater or equal to 2
Some(unsafe { std::str::from_utf8_unchecked(self.data.get_unchecked(2..)) })
} else {
None
Expand Down Expand Up @@ -767,6 +775,8 @@ where
self.partial_payload.extend_from_slice(&frame.payload);

if self.partial_opcode == OpCode::Text {
// SAFETY: self.utf8_valid_up_to is an index in self.partial_payload and cannot
// exceed its length
let (should_fail, valid_up_to) = utf8::should_fail_fast(
unsafe { self.partial_payload.get_unchecked(self.utf8_valid_up_to..) },
frame.is_final,
Expand Down Expand Up @@ -828,7 +838,7 @@ where
break;
};

// We know that the pending_messages are not empty
// SAFETY: We just ensured that the pending_messages are not empty
let item = unsafe { self.pending_messages.pop_front().unwrap_unchecked() };

// Encode it into the buffer
Expand Down Expand Up @@ -1031,7 +1041,7 @@ impl Encoder<Message> for WebsocketProtocol {
}
#[cfg(not(feature = "client"))]
{
// This allows for making the dependency on random generators
// SAFETY: This allows for making the dependency on random generators
// only required for clients, servers can avoid it entirely.
// Since it is not possible to create a stream with client role
// without the client builder (and that is locked behind the client feature),
Expand Down Expand Up @@ -1066,6 +1076,8 @@ impl Encoder<Message> for WebsocketProtocol {

if let Some(mask) = mask {
let start_of_data = dst.len() - chunk.len();
// SAFETY: We called dst.extend_from_slice(chunk), so start_of_data is an index
// in dst, to be exact, the lenth of dst before the extend_from_slice call
mask::frame(&mask, unsafe { dst.get_unchecked_mut(start_of_data..) }, 0);
}

Expand Down Expand Up @@ -1103,6 +1115,7 @@ impl Decoder for WebsocketProtocol {
// Opcode and payload length must be present
ensure_buffer_has_space!(src, 2);

// SAFETY: The ensure_buffer_has_space call has validated this
let fin_and_rsv = unsafe { src.get_unchecked(0) };
let payload_len_1 = unsafe { src.get_unchecked(1) };

Expand Down Expand Up @@ -1142,12 +1155,16 @@ impl Decoder for WebsocketProtocol {

if payload_length == 126 {
ensure_buffer_has_space!(src, 4);
// SAFETY: The ensure_buffer_has_space call has validated this
// A conversion from two u8s to a u16 cannot fail
payload_length = u16::from_be_bytes(unsafe {
src.get_unchecked(2..4).try_into().unwrap_unchecked()
}) as usize;
offset = 4;
} else if payload_length == 127 {
ensure_buffer_has_space!(src, 10);
// SAFETY: The ensure_buffer_has_space call has validated this
// A conversion from 8 u8s to a u64 cannot fail
payload_length = u64::from_be_bytes(unsafe {
src.get_unchecked(2..10).try_into().unwrap_unchecked()
}) as usize;
Expand Down Expand Up @@ -1188,8 +1205,8 @@ impl Decoder for WebsocketProtocol {
if mask {
let unmasked_until = offset + self.payload_in;

// This is very unsafe, but sound because the masking key
// and the payload do not overlap in src
// SAFETY: The masking key and the payload do not overlap in src
// TODO: Replace with split_at_mut_unchecked when stable
let (masking_key, to_unmask) = unsafe {
let masking_key_ptr =
src.get_unchecked(offset - 4..offset) as *const [u8];
Expand All @@ -1205,6 +1222,9 @@ impl Decoder for WebsocketProtocol {

self.payload_in = data_available;

// SAFETY: offset + utf8_valid_up_to is the index until which utf8 was
// validated for this frame and therefore guaranteed to be in bounds.
// self.payload_in is data_available, which is at most src.len()
let (should_fail, valid_up_to) = utf8::should_fail_fast(
unsafe {
src.get_unchecked(
Expand Down Expand Up @@ -1237,8 +1257,8 @@ impl Decoder for WebsocketProtocol {
if mask {
let unmasked_until = offset + self.payload_in;

// This is very unsafe, but sound because the masking key
// and the payload do not overlap in src
// SAFETY: The masking key and the payload do not overlap in src
// TODO: Replace with split_at_mut_unchecked when stable
let (masking_key, to_unmask) = unsafe {
let masking_key_ptr = src.get_unchecked(offset - 4..offset) as *const [u8];
let to_unmask_ptr = src
Expand Down
1 change: 1 addition & 0 deletions src/upgrade/client_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fn contains_ignore_ascii_case(mut haystack: &[u8], needle: &[u8]) -> bool {
}

while haystack.len() >= needle.len() {
// SAFETY: needle.len() will always be equal to or less than haystack.len()
if unsafe { haystack.get_unchecked(..needle.len()) }.eq_ignore_ascii_case(needle) {
return true;
}
Expand Down
2 changes: 2 additions & 0 deletions src/upgrade/server_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ impl Encoder<()> for Codec {
type Error = crate::Error;

fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<(), Self::Error> {
// SAFETY: This is never called. Encoder is implemented to satisfy requirements
// for Framed.
unsafe { unreachable_unchecked() }
}
}

0 comments on commit 0ba020d

Please sign in to comment.