diff --git a/quinn-proto/src/connection/mod.rs b/quinn-proto/src/connection/mod.rs index 3d7e91c70..9c62aad53 100644 --- a/quinn-proto/src/connection/mod.rs +++ b/quinn-proto/src/connection/mod.rs @@ -98,6 +98,9 @@ use timer::{Timer, TimerTable}; mod transmit_buf; use transmit_buf::TransmitBuf; +mod state; +use state::{State, StateType}; + /// Protocol state and logic for a single QUIC connection /// /// Objects of this type receive [`ConnectionEvent`]s and emit [`EndpointEvent`]s and application @@ -205,8 +208,6 @@ pub struct Connection { timers: TimerTable, /// Number of packets received which could not be authenticated authentication_failures: u64, - /// Why the connection was lost, if it has been - error: Option, // // Queued non-retransmittable 1-RTT data @@ -342,7 +343,7 @@ impl Connection { }; #[cfg(not(test))] let data_space = PacketSpace::new(now, SpaceId::Data, &mut rng); - let state = State::Handshake(state::Handshake { + let state = State::handshake(state::Handshake { rem_cid_set: side.is_server(), expected_token: Bytes::new(), client_hello: None, @@ -409,7 +410,6 @@ impl Connection { }, timers: TimerTable::default(), authentication_failures: 0, - error: None, close: false, ack_frequency: AckFrequencyState::new(get_max_ack_delay( @@ -486,8 +486,8 @@ impl Connection { return Some(Event::Stream(event)); } - if let Some(err) = self.error.take() { - return Some(Event::ConnectionLost { reason: err }); + if let Some(reason) = self.state.take_error() { + return Some(Event::ConnectionLost { reason }); } None @@ -880,12 +880,12 @@ impl Connection { // Once there's nothing more to send on the AVAILABLE paths, do the same for BACKUP paths // Check whether we need to send a close message - let close = match self.state { - State::Drained => { + let close = match self.state.as_type() { + StateType::Drained => { self.app_limited = true; return None; } - State::Draining | State::Closed(_) => { + StateType::Draining | StateType::Closed => { // self.close is only reset once the associated packet had been // encoded successfully if !self.close { @@ -1209,8 +1209,10 @@ impl Connection { ); if frame::ConnectionClose::SIZE_BOUND < builder.frame_space_remaining() { let max_frame_size = builder.frame_space_remaining(); - match self.state { - State::Closed(state::Closed { ref reason }) => { + match self.state.as_type() { + StateType::Closed => { + let reason: Close = + self.state.as_closed().expect("checked").clone().into(); if space_id == SpaceId::Data || reason.is_transport_layer() { reason.encode(&mut builder.frame_space_mut(), max_frame_size) } else { @@ -1222,7 +1224,7 @@ impl Connection { .encode(&mut builder.frame_space_mut(), max_frame_size) } } - State::Draining => frame::ConnectionClose { + StateType::Draining => frame::ConnectionClose { error_code: TransportErrorCode::NO_ERROR, frame_type: None, reason: Bytes::new(), @@ -1781,7 +1783,7 @@ impl Connection { match timer { Timer::Conn(timer) => match timer { ConnTimer::Close => { - self.state = State::Drained; + self.state.move_to_drained(None); self.endpoint_events.push_back(EndpointEventInner::Drained); } ConnTimer::Idle => { @@ -1952,7 +1954,7 @@ impl Connection { self.close_common(); self.set_close_timer(now); self.close = true; - self.state = State::Closed(state::Closed { reason }); + self.state.move_to_closed_local(reason); } } @@ -2991,7 +2993,7 @@ impl Connection { // the real server. From now on we should no longer allow the server to migrate // its address. if space_id == SpaceId::Handshake { - if let State::Handshake(ref mut hs) = self.state { + if let Some(hs) = self.state.as_handshake_mut() { hs.allow_server_migration = false; } } @@ -3107,11 +3109,10 @@ impl Connection { let path_id = PathId::ZERO; self.path_data_mut(path_id).total_recvd = len as u64; - match self.state { - State::Handshake(ref mut state) => { - state.expected_token = packet.header.token.clone(); - } - _ => unreachable!("first packet must be delivered in Handshake state"), + if let Some(hs) = self.state.as_handshake_mut() { + hs.expected_token = packet.header.token.clone(); + } else { + unreachable!("first packet must be delivered in Handshake state"); } // The first packet is always on PathId::ZERO @@ -3251,9 +3252,9 @@ impl Connection { } let offset = self.spaces[space].crypto_offset; let outgoing = Bytes::from(outgoing); - if let State::Handshake(ref mut state) = self.state { + if let Some(hs) = self.state.as_handshake_mut() { if space == SpaceId::Initial && offset == 0 && self.side.is_client() { - state.client_hello = Some(outgoing.clone()); + hs.client_hello = Some(outgoing.clone()); } } self.spaces[space].crypto_offset += outgoing.len() as u64; @@ -3401,15 +3402,17 @@ impl Connection { return; } if remote != self.path_data_mut(path_id).remote { - match self.state { - State::Handshake(ref hs) if hs.allow_server_migration => { + if let Some(hs) = self.state.as_handshake() { + if hs.allow_server_migration { trace!(?remote, prev = ?self.path_data(path_id).remote, "server migrated to new remote"); self.path_data_mut(path_id).remote = remote; - } - _ => { + } else { debug!("discarding packet with unexpected remote during handshake"); return; } + } else { + debug!("discarding packet with unexpected remote during handshake"); + return; } } } @@ -3467,7 +3470,7 @@ impl Connection { return; } else { if let Header::Initial(InitialHeader { ref token, .. }) = packet.header { - if let State::Handshake(ref hs) = self.state { + if let Some(hs) = self.state.as_handshake() { if self.side.is_server() && token != &hs.expected_token { // Clients must send the same retry token in every Initial. Initial // packets can be spoofed, so we discard rather than killing the @@ -3508,23 +3511,26 @@ impl Connection { // State transitions for error cases if let Err(conn_err) = result { - self.error = Some(conn_err.clone()); - self.state = match conn_err { - ConnectionError::ApplicationClosed(reason) => State::closed(reason), - ConnectionError::ConnectionClosed(reason) => State::closed(reason), + match conn_err { + ConnectionError::ApplicationClosed(reason) => self.state.move_to_closed(reason), + ConnectionError::ConnectionClosed(reason) => self.state.move_to_closed(reason), ConnectionError::Reset | ConnectionError::TransportError(TransportError { code: TransportErrorCode::AEAD_LIMIT_REACHED, .. - }) => State::Drained, + }) => { + self.state.move_to_drained(Some(conn_err)); + } ConnectionError::TimedOut => { unreachable!("timeouts aren't generated by packet processing"); } ConnectionError::TransportError(err) => { debug!("closing connection due to transport error: {}", err); - State::closed(err) + self.state.move_to_closed(err); + } + ConnectionError::VersionMismatch => { + self.state.move_to_draining(Some(conn_err)); } - ConnectionError::VersionMismatch => State::Draining, ConnectionError::LocallyClosed => { unreachable!("LocallyClosed isn't generated by packet processing"); } @@ -3548,7 +3554,7 @@ impl Connection { } // Transmit CONNECTION_CLOSE if necessary - if let State::Closed(_) = self.state { + if matches!(self.state.as_type(), StateType::Closed) { // If there is no PathData for this PathId the packet was for a brand new // path. It was a valid packet however, so the remote is valid and we want to // send CONNECTION_CLOSE. @@ -3576,8 +3582,8 @@ impl Connection { trace!(%path_id, ?number, "discarding packet for unknown path"); return Ok(()); } - let state = match self.state { - State::Established => { + let state = match self.state.as_type() { + StateType::Established => { match packet.header.space() { SpaceId::Data => { self.process_payload(now, remote, path_id, number.unwrap(), packet)? @@ -3591,7 +3597,7 @@ impl Connection { } return Ok(()); } - State::Closed(_) => { + StateType::Closed => { for result in frame::Iter::new(packet.payload.freeze())? { let frame = match result { Ok(frame) => frame, @@ -3607,16 +3613,16 @@ impl Connection { self.stats.frame_rx.record(&frame); - if let Frame::Close(_) = frame { + if let Frame::Close(_error) = frame { trace!("draining"); - self.state = State::Draining; + self.state.move_to_draining(None); break; } } return Ok(()); } - State::Draining | State::Drained => return Ok(()), - State::Handshake(ref mut state) => state, + StateType::Draining | StateType::Drained => return Ok(()), + StateType::Handshake => self.state.as_handshake_mut().expect("checked"), }; match packet.header { @@ -3706,7 +3712,7 @@ impl Connection { }; *token = packet.payload.freeze().split_to(token_len); - self.state = State::Handshake(state::Handshake { + self.state = State::handshake(state::Handshake { expected_token: Bytes::new(), rem_cid_set: false, client_hello: None, @@ -3790,7 +3796,7 @@ impl Connection { } self.events.push_back(Event::Connected); - self.state = State::Established; + self.state.move_to_established(); trace!("established"); // Multipath can only be enabled after the state has reached Established. @@ -3814,7 +3820,7 @@ impl Connection { self.rem_handshake_cid = rem_cid; self.orig_rem_cid = rem_cid; state.rem_cid_set = true; - self.state = State::Handshake(state); + self.state.move_to_handshake(state); } else if rem_cid != self.rem_handshake_cid { debug!( "discarding packet with mismatched remote CID: {} != {}", @@ -3912,12 +3918,10 @@ impl Connection { self.on_path_ack_received(now, packet.header.space(), ack)?; } Frame::Close(reason) => { - self.error = Some(reason.into()); - self.state = State::Draining; + self.state.move_to_draining(Some(reason.into())); return Ok(()); } _ => { - dbg!(&frame); let mut err = TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake"); err.frame = Some(frame.ty()); @@ -4606,8 +4610,7 @@ impl Connection { self.streams.queue_max_stream_id(pending); if let Some(reason) = close { - self.error = Some(reason.into()); - self.state = State::Draining; + self.state.move_to_draining(Some(reason.into())); self.close = true; } @@ -5806,8 +5809,7 @@ impl Connection { /// Terminate the connection instantly, without sending a close packet fn kill(&mut self, reason: ConnectionError) { self.close_common(); - self.error = Some(reason); - self.state = State::Drained; + self.state.move_to_drained(Some(reason)); self.endpoint_events.push_back(EndpointEventInner::Drained); } @@ -6093,10 +6095,13 @@ impl ConnectionSide { fn remote_may_migrate(&self, state: &State) -> bool { match self { Self::Server { server_config } => server_config.migration, - Self::Client { .. } => match state { - State::Handshake(handshake) => handshake.allow_server_migration, - _ => false, - }, + Self::Client { .. } => { + if let Some(hs) = state.as_handshake() { + hs.allow_server_migration + } else { + false + } + } } } @@ -6273,83 +6278,6 @@ pub struct MultipathNotNegotiated { _private: (), } -#[allow(unreachable_pub)] // fuzzing only -#[derive(Clone)] -pub enum State { - Handshake(state::Handshake), - Established, - Closed(state::Closed), - Draining, - /// Waiting for application to call close so we can dispose of the resources - Drained, -} - -impl State { - fn closed>(reason: R) -> Self { - Self::Closed(state::Closed { - reason: reason.into(), - }) - } - - fn is_handshake(&self) -> bool { - matches!(*self, Self::Handshake(_)) - } - - fn is_established(&self) -> bool { - matches!(*self, Self::Established) - } - - fn is_closed(&self) -> bool { - matches!(*self, Self::Closed(_) | Self::Draining | Self::Drained) - } - - fn is_drained(&self) -> bool { - matches!(*self, Self::Drained) - } -} - -mod state { - use super::*; - - #[allow(unreachable_pub)] // fuzzing only - #[derive(Clone)] - pub struct Handshake { - /// Whether the remote CID has been set by the peer yet - /// - /// Always set for servers - pub(super) rem_cid_set: bool, - /// Stateless retry token received in the first Initial by a server. - /// - /// Must be present in every Initial. Always empty for clients. - pub(super) expected_token: Bytes, - /// First cryptographic message - /// - /// Only set for clients - pub(super) client_hello: Option, - /// Whether the server address is allowed to migrate - /// - /// We allow the server to migrate during the handshake as long as we have not - /// received an authenticated handshake packet: it can send a response from a - /// different address than we sent the initial to. This allows us to send the - /// initial packet over multiple paths - by means of an IPv6 ULA address that copies - /// the packets sent to it to multiple destinations - and accept one response. - /// - /// This is only ever set to true if for a client which hasn't yet received an - /// authenticated handshake packet. It is set back to false in - /// [`Connection::on_packet_authenticated`]. - /// - /// THIS IS NOT RFC 9000 COMPLIANT! A server is not allowed to migrate addresses, - /// other than using the preferred-address transport parameter. - pub(super) allow_server_migration: bool, - } - - #[allow(unreachable_pub)] // fuzzing only - #[derive(Clone)] - pub struct Closed { - pub(super) reason: Close, - } -} - /// Events of interest to the application #[derive(Debug)] pub enum Event { diff --git a/quinn-proto/src/connection/state.rs b/quinn-proto/src/connection/state.rs new file mode 100644 index 000000000..7b118de10 --- /dev/null +++ b/quinn-proto/src/connection/state.rs @@ -0,0 +1,343 @@ +use bytes::Bytes; + +use crate::frame::Close; +use crate::{ + ApplicationClose, ConnectionClose, ConnectionError, TransportError, TransportErrorCode, +}; + +#[allow(unreachable_pub)] // fuzzing only +#[derive(Debug, Clone)] +pub struct State { + /// Nested [`InnerState`] to enforce all state transitions are done in this module. + inner: InnerState, +} + +impl State { + pub(super) fn as_handshake_mut(&mut self) -> Option<&mut Handshake> { + if let InnerState::Handshake(ref mut hs) = self.inner { + Some(hs) + } else { + None + } + } + + pub(super) fn as_handshake(&self) -> Option<&Handshake> { + if let InnerState::Handshake(ref hs) = self.inner { + Some(hs) + } else { + None + } + } + + pub(super) fn as_closed(&self) -> Option<&CloseReason> { + if let InnerState::Closed { + ref remote_reason, .. + } = self.inner + { + Some(remote_reason) + } else { + None + } + } + + #[cfg(test)] + pub(super) fn established() -> Self { + Self { + inner: InnerState::Established, + } + } + + pub(super) fn handshake(hs: Handshake) -> Self { + Self { + inner: InnerState::Handshake(hs), + } + } + + pub(super) fn move_to_handshake(&mut self, hs: Handshake) { + self.inner = InnerState::Handshake(hs); + } + + pub(super) fn move_to_established(&mut self) { + self.inner = InnerState::Established; + } + + /// Moves to a draining state. + /// + /// Panics if the state was already drained. + pub(super) fn move_to_drained(&mut self, error: Option) { + let (error, is_local) = if let Some(error) = error { + (Some(error), false) + } else { + let error = match &mut self.inner { + InnerState::Draining { error, .. } => error.take(), + InnerState::Drained { .. } => panic!("invalid state transition drained -> drained"), + InnerState::Closed { error_read, .. } if *error_read => None, + InnerState::Closed { remote_reason, .. } => { + let error = match remote_reason.clone().into() { + ConnectionError::ConnectionClosed(close) => { + if close.error_code == TransportErrorCode::PROTOCOL_VIOLATION { + ConnectionError::TransportError(close.error_code.into()) + } else { + ConnectionError::ConnectionClosed(close) + } + } + e => e, + }; + Some(error) + } + InnerState::Handshake(_) | InnerState::Established => None, + }; + (error, self.is_local_close()) + }; + self.inner = InnerState::Drained { error, is_local }; + } + + /// Moves to a draining state. + /// + /// Panics if the state is already draining or drained. + pub(super) fn move_to_draining(&mut self, error: Option) { + assert!( + matches!( + self.inner, + InnerState::Handshake(_) | InnerState::Established | InnerState::Closed { .. } + ), + "invalid state transition {:?} -> draining", + self.as_type() + ); + let is_local = self.is_local_close(); + self.inner = InnerState::Draining { error, is_local }; + } + + fn is_local_close(&self) -> bool { + match self.inner { + InnerState::Handshake(_) => false, + InnerState::Established => false, + InnerState::Closed { is_local, .. } => is_local, + InnerState::Draining { is_local, .. } => is_local, + InnerState::Drained { is_local, .. } => is_local, + } + } + + /// Moves to a closed state after a remote error is received. + /// + /// Panics if the state is later than established. + pub(super) fn move_to_closed>(&mut self, reason: R) { + assert!( + matches!( + self.inner, + InnerState::Handshake(_) | InnerState::Established + ), + "invalid state transition {:?} -> closed", + self.as_type() + ); + self.inner = InnerState::Closed { + error_read: false, + remote_reason: reason.into(), + is_local: false, + }; + } + + /// Moves to a closed state after a local error. + /// + /// Panics if the state is later than established. + pub(super) fn move_to_closed_local>(&mut self, reason: R) { + assert!( + matches!( + self.inner, + InnerState::Handshake(_) | InnerState::Established + ), + "invalid state transition {:?} -> closed (local)", + self.as_type() + ); + self.inner = InnerState::Closed { + error_read: false, + remote_reason: reason.into(), + is_local: true, + }; + } + + pub(super) fn is_handshake(&self) -> bool { + matches!(self.inner, InnerState::Handshake(_)) + } + + pub(super) fn is_established(&self) -> bool { + matches!(self.inner, InnerState::Established) + } + + pub(super) fn is_closed(&self) -> bool { + matches!( + self.inner, + InnerState::Closed { .. } | InnerState::Draining { .. } | InnerState::Drained { .. } + ) + } + + pub(super) fn is_drained(&self) -> bool { + matches!(self.inner, InnerState::Drained { .. }) + } + + pub(super) fn take_error(&mut self) -> Option { + match &mut self.inner { + InnerState::Draining { error, is_local } => { + if !*is_local { + error.take() + } else { + None + } + } + InnerState::Drained { error, is_local } => { + if !*is_local { + error.take() + } else { + None + } + } + InnerState::Closed { + remote_reason, + is_local: local_reason, + error_read, + } => { + if *error_read { + None + } else { + *error_read = true; + if *local_reason { + None + } else { + Some(remote_reason.clone().into()) + } + } + } + InnerState::Handshake(_) | InnerState::Established => None, + } + } + + pub(super) fn as_type(&self) -> StateType { + match self.inner { + InnerState::Handshake(_) => StateType::Handshake, + InnerState::Established => StateType::Established, + InnerState::Closed { .. } => StateType::Closed, + InnerState::Draining { .. } => StateType::Draining, + InnerState::Drained { .. } => StateType::Drained, + } + } +} + +#[derive(Debug, Clone)] +pub(super) enum StateType { + Handshake, + Established, + Closed, + Draining, + Drained, +} + +#[derive(Debug, Clone)] +pub(super) enum CloseReason { + TransportError(TransportError), + Connection(ConnectionClose), + Application(ApplicationClose), +} + +impl From for CloseReason { + fn from(x: TransportError) -> Self { + Self::TransportError(x) + } +} +impl From for CloseReason { + fn from(x: ConnectionClose) -> Self { + Self::Connection(x) + } +} +impl From for CloseReason { + fn from(x: ApplicationClose) -> Self { + Self::Application(x) + } +} + +impl From for CloseReason { + fn from(value: Close) -> Self { + match value { + Close::Application(reason) => Self::Application(reason), + Close::Connection(reason) => Self::Connection(reason), + } + } +} + +impl From for ConnectionError { + fn from(value: CloseReason) -> Self { + match value { + CloseReason::TransportError(err) => Self::TransportError(err), + CloseReason::Connection(reason) => Self::ConnectionClosed(reason), + CloseReason::Application(reason) => Self::ApplicationClosed(reason), + } + } +} + +impl From for Close { + fn from(value: CloseReason) -> Self { + match value { + CloseReason::TransportError(err) => Self::Connection(err.into()), + CloseReason::Connection(reason) => Self::Connection(reason), + CloseReason::Application(reason) => Self::Application(reason), + } + } +} + +#[derive(Debug, Clone)] +enum InnerState { + Handshake(Handshake), + Established, + Closed { + /// The reason the remote closed the connection, or the reason we are sending to the remote. + remote_reason: CloseReason, + /// Set to true if we closed the connection locally. + is_local: bool, + /// Did we read this as error already? + error_read: bool, + }, + Draining { + /// Why the connection was lost, if it has been. + error: Option, + /// Set to true if we closed the connection locally. + is_local: bool, + }, + /// Waiting for application to call close so we can dispose of the resources. + Drained { + /// Why the connection was lost, if it has been. + error: Option, + /// Set to true if we closed the connection locally. + is_local: bool, + }, +} + +#[allow(unreachable_pub)] // fuzzing only +#[derive(Debug, Clone)] +pub struct Handshake { + /// Whether the remote CID has been set by the peer yet. + /// + /// Always set for servers. + pub(super) rem_cid_set: bool, + /// Stateless retry token received in the first Initial by a server. + /// + /// Must be present in every Initial. Always empty for clients. + pub(super) expected_token: Bytes, + /// First cryptographic message. + /// + /// Only set for clients. + pub(super) client_hello: Option, + /// Whether the server address is allowed to migrate. + /// + /// We allow the server to migrate during the handshake as long as we have not + /// received an authenticated handshake packet: it can send a response from a + /// different address than we sent the initial to. This allows us to send the + /// initial packet over multiple paths - by means of an IPv6 ULA address that copies + /// the packets sent to it to multiple destinations - and accept one response. + /// + /// This is only ever set to true if for a client which hasn't yet received an + /// authenticated handshake packet. It is set back to false in + /// [`super::Connection::on_packet_authenticated`]. + /// + /// THIS IS NOT RFC 9000 COMPLIANT! A server is not allowed to migrate addresses, + /// other than using the preferred-address transport parameter. + pub(super) allow_server_migration: bool, +} diff --git a/quinn-proto/src/connection/streams/state.rs b/quinn-proto/src/connection/streams/state.rs index 7837b0707..c87bcc09b 100644 --- a/quinn-proto/src/connection/streams/state.rs +++ b/quinn-proto/src/connection/streams/state.rs @@ -1280,7 +1280,7 @@ mod tests { ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let id = Streams { state: &mut server, conn_state: &state, @@ -1341,7 +1341,7 @@ mod tests { ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let mut streams = Streams { state: &mut server, conn_state: &state, @@ -1397,7 +1397,7 @@ mod tests { ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let mut streams = Streams { state: &mut server, conn_state: &state, @@ -1467,7 +1467,7 @@ mod tests { ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let mut streams = Streams { state: &mut server, conn_state: &state, @@ -1548,7 +1548,7 @@ mod tests { ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let mut streams = Streams { state: &mut server, conn_state: &state, @@ -1654,7 +1654,7 @@ mod tests { initial_max_stream_data_uni: 42u32.into(), ..TransportParameters::default() }); - let (mut pending, state) = (Retransmits::default(), ConnState::Established); + let (mut pending, state) = (Retransmits::default(), ConnState::established()); let mut streams = Streams { state: &mut server, conn_state: &state, @@ -1957,7 +1957,7 @@ mod tests { assert_eq!(server.poll(), None); let mut retransmits = Retransmits::default(); - let conn_state = ConnState::Established; + let conn_state = ConnState::established(); let stream_id = Streams { state: &mut server, @@ -2032,7 +2032,7 @@ mod tests { assert_eq!(server.poll(), None); let mut retransmits = Retransmits::default(); - let conn_state = ConnState::Established; + let conn_state = ConnState::established(); let stream_id = Streams { state: &mut server,