From 0ba020dcd945981dba9c1134f5e2ca22a3cd7b6e Mon Sep 17 00:00:00 2001 From: Jens Reidel Date: Wed, 19 Jul 2023 18:10:47 +0200 Subject: [PATCH] Add safety comments for all unsafe usage Signed-off-by: Jens Reidel --- src/proto.rs | 32 ++++++++++++++++++++++++++------ src/upgrade/client_request.rs | 1 + src/upgrade/server_response.rs | 2 ++ 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/proto.rs b/src/proto.rs index cb91ffc4d66..98e6e352093 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -371,6 +371,8 @@ impl Message { if data.is_empty() { Ok(Self { opcode, data }) } else { + // SAFETY: The Decoder ensures that close frames consist of at least two bytes + // A conversion from two u8s to a u16 cannot fail. let close_code_value = u16::from_be_bytes(unsafe { data.get_unchecked(0..2).try_into().unwrap_unchecked() }); @@ -383,6 +385,8 @@ impl Message { // Verify that the reason is allowed if data.len() > 2 { + // SAFETY: The Decoder ensures that close frames consist of at least two + // bytes utf8::parse_str(unsafe { data.get_unchecked(2..) })?; } @@ -507,6 +511,8 @@ impl Message { /// binary or binary and invalid UTF-8. pub fn as_text(&self) -> Result<&str, ProtocolError> { match self.opcode { + // SAFETY: UTF-8 is validated by the Decoder and/or when the message is assembled from + // frames in the case of text messages. 
OpCode::Text => Ok(unsafe { std::str::from_utf8_unchecked(&self.data) }), OpCode::Binary => Ok(utf8::parse_str(&self.data)?), _ => Err(ProtocolError::MessageHasWrongOpcode), @@ -522,6 +528,7 @@ impl Message { pub fn as_close(&self) -> Result<(Option, Option<&str>), ProtocolError> { if self.opcode == OpCode::Close { let close_code = if self.data.len() >= 2 { + // SAFETY: self.data.len() is greater than or equal to 2 let close_code_value = u16::from_be_bytes(unsafe { self.data.get_unchecked(0..2).try_into().unwrap_unchecked() }); @@ -531,6 +538,7 @@ impl Message { }; let reason = if self.data.len() > 2 { + // SAFETY: self.data.len() is greater than 2 Some(unsafe { std::str::from_utf8_unchecked(self.data.get_unchecked(2..)) }) } else { None @@ -767,6 +775,8 @@ where self.partial_payload.extend_from_slice(&frame.payload); if self.partial_opcode == OpCode::Text { + // SAFETY: self.utf8_valid_up_to is an index in self.partial_payload and cannot + // exceed its length let (should_fail, valid_up_to) = utf8::should_fail_fast( unsafe { self.partial_payload.get_unchecked(self.utf8_valid_up_to..) }, frame.is_final, @@ -828,7 +838,7 @@ where break; }; - // We know that the pending_messages are not empty + // SAFETY: We just ensured that the pending_messages are not empty let item = unsafe { self.pending_messages.pop_front().unwrap_unchecked() }; // Encode it into the buffer @@ -1031,7 +1041,7 @@ impl Encoder for WebsocketProtocol { } #[cfg(not(feature = "client"))] { - // This allows for making the dependency on random generators + // SAFETY: This allows for making the dependency on random generators // only required for clients, servers can avoid it entirely. 
// Since it is not possible to create a stream with client role // without the client builder (and that is locked behind the client feature), @@ -1066,6 +1076,8 @@ impl Encoder for WebsocketProtocol { if let Some(mask) = mask { let start_of_data = dst.len() - chunk.len(); + // SAFETY: We called dst.extend_from_slice(chunk), so start_of_data is an index + // in dst, to be exact, the length of dst before the extend_from_slice call mask::frame(&mask, unsafe { dst.get_unchecked_mut(start_of_data..) }, 0); } @@ -1103,6 +1115,7 @@ impl Decoder for WebsocketProtocol { // Opcode and payload length must be present ensure_buffer_has_space!(src, 2); + // SAFETY: The ensure_buffer_has_space call has validated this let fin_and_rsv = unsafe { src.get_unchecked(0) }; let payload_len_1 = unsafe { src.get_unchecked(1) }; @@ -1142,12 +1155,16 @@ impl Decoder for WebsocketProtocol { if payload_length == 126 { ensure_buffer_has_space!(src, 4); + // SAFETY: The ensure_buffer_has_space call has validated this + // A conversion from two u8s to a u16 cannot fail payload_length = u16::from_be_bytes(unsafe { src.get_unchecked(2..4).try_into().unwrap_unchecked() }) as usize; offset = 4; } else if payload_length == 127 { ensure_buffer_has_space!(src, 10); + // SAFETY: The ensure_buffer_has_space call has validated this + // A conversion from 8 u8s to a u64 cannot fail payload_length = u64::from_be_bytes(unsafe { src.get_unchecked(2..10).try_into().unwrap_unchecked() }) as usize; @@ -1188,8 +1205,8 @@ impl Decoder for WebsocketProtocol { if mask { let unmasked_until = offset + self.payload_in; - // This is very unsafe, but sound because the masking key - // and the payload do not overlap in src + // SAFETY: The masking key and the payload do not overlap in src + // TODO: Replace with split_at_mut_unchecked when stable let (masking_key, to_unmask) = unsafe { let masking_key_ptr = src.get_unchecked(offset - 4..offset) as *const [u8]; @@ -1205,6 +1222,9 @@ impl Decoder for WebsocketProtocol { 
self.payload_in = data_available; + // SAFETY: offset + utf8_valid_up_to is the index until which utf8 was + // validated for this frame and therefore guaranteed to be in bounds. + // self.payload_in is data_available, which is at most src.len() let (should_fail, valid_up_to) = utf8::should_fail_fast( unsafe { src.get_unchecked( @@ -1237,8 +1257,8 @@ impl Decoder for WebsocketProtocol { if mask { let unmasked_until = offset + self.payload_in; - // This is very unsafe, but sound because the masking key - // and the payload do not overlap in src + // SAFETY: The masking key and the payload do not overlap in src + // TODO: Replace with split_at_mut_unchecked when stable let (masking_key, to_unmask) = unsafe { let masking_key_ptr = src.get_unchecked(offset - 4..offset) as *const [u8]; let to_unmask_ptr = src diff --git a/src/upgrade/client_request.rs b/src/upgrade/client_request.rs index f9c9f122989..345e4d6768c 100644 --- a/src/upgrade/client_request.rs +++ b/src/upgrade/client_request.rs @@ -18,6 +18,7 @@ fn contains_ignore_ascii_case(mut haystack: &[u8], needle: &[u8]) -> bool { } while haystack.len() >= needle.len() { + // SAFETY: needle.len() will always be equal to or less than haystack.len() if unsafe { haystack.get_unchecked(..needle.len()) }.eq_ignore_ascii_case(needle) { return true; } diff --git a/src/upgrade/server_response.rs b/src/upgrade/server_response.rs index a68a2cbf3ff..bfbce8119bd 100644 --- a/src/upgrade/server_response.rs +++ b/src/upgrade/server_response.rs @@ -104,6 +104,8 @@ impl Encoder<()> for Codec { type Error = crate::Error; fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<(), Self::Error> { + // SAFETY: This is never called. Encoder is implemented to satisfy requirements + // for Framed. unsafe { unreachable_unchecked() } } }