diff --git a/web-transport-proto/src/capsule.rs b/web-transport-proto/src/capsule.rs index baeb654..cd1b1ab 100644 --- a/web-transport-proto/src/capsule.rs +++ b/web-transport-proto/src/capsule.rs @@ -15,72 +15,82 @@ const MAX_MESSAGE_SIZE: usize = 1024; #[derive(Debug, Clone, PartialEq, Eq)] pub enum Capsule { CloseWebTransportSession { code: u32, reason: String }, + Grease { num: u64 }, Unknown { typ: VarInt, payload: Bytes }, } impl Capsule { pub fn decode(buf: &mut B) -> Result { - loop { - let typ = VarInt::decode(buf)?; - let length = VarInt::decode(buf)?; + let typ = VarInt::decode(buf)?; + let length = VarInt::decode(buf)?; - let mut payload = buf.take(length.into_inner() as usize); - if payload.remaining() > MAX_MESSAGE_SIZE { - return Err(CapsuleError::MessageTooLong); - } + let mut payload = buf.take(length.into_inner() as usize); - if payload.remaining() < payload.limit() { - return Err(CapsuleError::UnexpectedEnd); - } + // Check declared length first - reject immediately if too large + if payload.limit() > MAX_MESSAGE_SIZE { + return Err(CapsuleError::MessageTooLong); + } - match typ.into_inner() { - CLOSE_WEBTRANSPORT_SESSION_TYPE => { - if payload.remaining() < 4 { - return Err(CapsuleError::UnexpectedEnd); - } + // Then check if all declared bytes are buffered + if payload.remaining() < payload.limit() { + return Err(CapsuleError::UnexpectedEnd); + } - let error_code = payload.get_u32(); + let typ_val = typ.into_inner(); - let message_len = payload.remaining(); - if message_len > MAX_MESSAGE_SIZE { - return Err(CapsuleError::MessageTooLong); - } + if let Some(num) = is_grease(typ_val) { + payload.advance(payload.remaining()); + return Ok(Self::Grease { num }); + } - let mut message_bytes = vec![0u8; message_len]; - payload.copy_to_slice(&mut message_bytes); + match typ_val { + CLOSE_WEBTRANSPORT_SESSION_TYPE => { + if payload.remaining() < 4 { + return Err(CapsuleError::UnexpectedEnd); + } - let error_message = - String::from_utf8(message_bytes).map_err(|_| CapsuleError::InvalidUtf8)?; + let error_code = payload.get_u32(); - return Ok(Self::CloseWebTransportSession { - code: error_code, - reason: error_message, - }); - } - t if is_grease(t) => continue, - _ => { - // Unknown capsule type - store it - let mut payload_bytes = vec![0u8; payload.remaining()]; - payload.copy_to_slice(&mut payload_bytes); - return Ok(Self::Unknown { - typ, - payload: Bytes::from(payload_bytes), - }); + let message_len = payload.remaining(); + if message_len > MAX_MESSAGE_SIZE { + return Err(CapsuleError::MessageTooLong); } + + let mut message_bytes = vec![0u8; message_len]; + payload.copy_to_slice(&mut message_bytes); + + let error_message = + String::from_utf8(message_bytes).map_err(|_| CapsuleError::InvalidUtf8)?; + + Ok(Self::CloseWebTransportSession { + code: error_code, + reason: error_message, + }) + } + _ => { + let mut payload_bytes = vec![0u8; payload.remaining()]; + payload.copy_to_slice(&mut payload_bytes); + Ok(Self::Unknown { + typ, + payload: Bytes::from(payload_bytes), + }) } } } - pub async fn read(stream: &mut S) -> Result { + pub async fn read(stream: &mut S) -> Result, CapsuleError> { let mut buf = Vec::new(); loop { if stream.read_buf(&mut buf).await? == 0 { + if buf.is_empty() { + return Ok(None); + } return Err(CapsuleError::UnexpectedEnd); } let mut limit = std::io::Cursor::new(&buf); match Self::decode(&mut limit) { - Ok(capsule) => return Ok(capsule), + Ok(capsule) => return Ok(Some(capsule)), Err(CapsuleError::UnexpectedEnd) => continue, Err(e) => return Err(e), } @@ -108,6 +118,19 @@ impl Capsule { // Encode the error message buf.put_slice(error_message.as_bytes()); } + Self::Grease { num } => { + // Generate grease type: 0x29 * N + 0x17 + // Check for overflow + let grease_type = num + .checked_mul(0x29) + .and_then(|v| v.checked_add(0x17)) + .expect("grease num value would overflow u64"); + + VarInt::from_u64(grease_type).unwrap().encode(buf); + + // Grease capsules have zero-length payload + VarInt::from_u32(0).encode(buf); + } Self::Unknown { typ, payload } => { // Encode the capsule type typ.encode(buf); @@ -129,13 +152,17 @@ impl Capsule { } } -fn is_grease(val: u64) -> bool { - if val < 0x21 { - return false; +// RFC 9297 Section 5.4: Capsule types of the form 0x29 * N + 0x17 +// Returns Some(N) if the value is a grease type, None otherwise +fn is_grease(val: u64) -> Option { + if val < 0x17 { + return None; } - #[allow(unknown_lints, clippy::manual_is_multiple_of)] - { - (val - 0x21) % 0x1f == 0 + let num = (val - 0x17) / 0x29; + if val == 0x29 * num + 0x17 { + Some(num) + } else { + None } } @@ -311,4 +338,39 @@ mod tests { assert_eq!(capsule, decoded); assert_eq!(read_buf.len(), 0); } + + #[test] + fn test_grease_capsule() { + // Test grease formula: 0x29 * N + 0x17 + for num in [0, 1, 5, 100, 1000] { + let capsule = Capsule::Grease { num }; + + let mut buf = Vec::new(); + capsule.encode(&mut buf); + + let mut read_buf = buf.as_slice(); + let decoded = Capsule::decode(&mut read_buf).unwrap(); + + assert_eq!(capsule, decoded); + assert_eq!(read_buf.len(), 0); + } + } + + #[test] + fn test_grease_values() { + // Verify specific grease type values + assert_eq!(is_grease(0x17), Some(0)); // N=0 + assert_eq!(is_grease(0x40), Some(1)); // N=1: 0x29 + 0x17 = 0x40 + assert_eq!(is_grease(0x69), Some(2)); // N=2: 0x29*2 + 0x17 = 0x69 + assert_eq!(is_grease(0x18), None); // Not a grease value + assert_eq!(is_grease(0x41), None); // Not a grease value + } + + #[test] + #[should_panic(expected = "grease num value would overflow u64")] + fn test_grease_overflow() { + let capsule = Capsule::Grease { num: u64::MAX }; + let mut buf = Vec::new(); + capsule.encode(&mut buf); + } } diff --git a/web-transport-quiche/src/connection.rs b/web-transport-quiche/src/connection.rs index fbc37d6..5d0f7ef 100644 --- a/web-transport-quiche/src/connection.rs +++ b/web-transport-quiche/src/connection.rs @@ -112,15 +112,23 @@ impl Connection { loop { match web_transport_proto::Capsule::read(&mut recv).await { - Ok(web_transport_proto::Capsule::CloseWebTransportSession { code, reason }) => { + Ok(Some(web_transport_proto::Capsule::CloseWebTransportSession { + code, + reason, + })) => { // TODO We shouldn't be closing the QUIC connection with the same error. // Instead, we should return it to the application. self.close(code, &reason); return; } - Ok(web_transport_proto::Capsule::Unknown { typ, payload }) => { + Ok(Some(web_transport_proto::Capsule::Grease { .. })) => {} + Ok(Some(web_transport_proto::Capsule::Unknown { typ, payload })) => { tracing::warn!("unknown capsule: type={typ} size={}", payload.len()); } + Ok(None) => { + // Stream closed without capsule + return; + } Err(_) => { self.close(500, "capsule error"); return; diff --git a/web-transport-quinn/src/session.rs b/web-transport-quinn/src/session.rs index 880a77d..20484ea 100644 --- a/web-transport-quinn/src/session.rs +++ b/web-transport-quinn/src/session.rs @@ -103,12 +103,19 @@ impl Session { async fn run_closed(&mut self, mut connect: ConnectComplete) -> (u32, String) { loop { match web_transport_proto::Capsule::read(&mut connect.recv).await { - Ok(web_transport_proto::Capsule::CloseWebTransportSession { code, reason }) => { + Ok(Some(web_transport_proto::Capsule::CloseWebTransportSession { + code, + reason, + })) => { return (code, reason); } - Ok(web_transport_proto::Capsule::Unknown { typ, payload }) => { + Ok(Some(web_transport_proto::Capsule::Grease { .. })) => {} + Ok(Some(web_transport_proto::Capsule::Unknown { typ, payload })) => { tracing::warn!(%typ, size = payload.len(), "unknown capsule"); } + Ok(None) => { + return (0, "stream closed".to_string()); + } Err(_) => { return (1, "capsule error".to_string()); }