diff --git a/quinn-proto/src/endpoint.rs b/quinn-proto/src/endpoint.rs index 36546448e..47b4394f9 100644 --- a/quinn-proto/src/endpoint.rs +++ b/quinn-proto/src/endpoint.rs @@ -573,7 +573,7 @@ impl Endpoint { if self.cids_exhausted() { debug!("refusing connection"); - self.index.remove_initial(incoming.orig_dst_cid); + self.index.remove_initial(dst_cid); return Err(AcceptError { cause: ConnectionError::CidsExhausted, response: Some(self.initial_close( @@ -602,7 +602,7 @@ impl Endpoint { .is_err() { debug!(packet_number, "failed to authenticate initial packet"); - self.index.remove_initial(incoming.orig_dst_cid); + self.index.remove_initial(dst_cid); return Err(AcceptError { cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(), response: None, @@ -651,9 +651,7 @@ impl Endpoint { transport_config, remote_address_validated, ); - if dst_cid.len() != 0 { - self.index.insert_initial(dst_cid, ch); - } + self.index.insert_initial(dst_cid, ch); match conn.handle_first_packet( now, @@ -802,7 +800,7 @@ impl Endpoint { /// Clean up endpoint data structures associated with an `Incoming`. fn clean_up_incoming(&mut self, incoming: &Incoming) { - self.index.remove_initial(incoming.orig_dst_cid); + self.index.remove_initial(incoming.packet.header.dst_cid); let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx); self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes; } @@ -864,6 +862,7 @@ impl Endpoint { cids_issued, loc_cids, addresses, + side, reset_token: None, }); debug_assert_eq!(id, ch.0, "connection handle allocation out of sync"); @@ -994,6 +993,8 @@ struct ConnectionIndex { /// Identifies connections based on the initial DCID the peer utilized /// /// Uses a standard `HashMap` to protect against hash collision attacks. + /// + /// Used by the server, not the client. connection_ids_initial: HashMap, /// Identifies connections based on locally created CIDs /// @@ -1022,17 +1023,27 @@ struct ConnectionIndex { impl ConnectionIndex { /// Associate an incoming connection with its initial destination CID fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) { + if dst_cid.len() == 0 { + return; + } self.connection_ids_initial .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key)); } /// Remove an association with an initial destination CID fn remove_initial(&mut self, dst_cid: ConnectionId) { - self.connection_ids_initial.remove(&dst_cid); + if dst_cid.len() == 0 { + return; + } + let removed = self.connection_ids_initial.remove(&dst_cid); + debug_assert!(removed.is_some()); } /// Associate a connection with its initial destination CID fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) { + if dst_cid.len() == 0 { + return; + } self.connection_ids_initial .insert(dst_cid, RouteDatagramTo::Connection(connection)); } @@ -1070,8 +1081,8 @@ impl ConnectionIndex { /// Remove all references to a connection fn remove(&mut self, conn: &ConnectionMeta) { - if conn.init_cid.len() > 0 { - self.connection_ids_initial.remove(&conn.init_cid); + if conn.side.is_server() { + self.remove_initial(conn.init_cid); } for cid in conn.loc_cids.values() { self.connection_ids.remove(cid); @@ -1126,6 +1137,7 @@ pub(crate) struct ConnectionMeta { /// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't /// bother keeping it up to date. addresses: FourTuple, + side: Side, /// Reset token provided by the peer for the CID we're currently sending to, and the address /// being sent to reset_token: Option<(SocketAddr, ResetToken)>, diff --git a/quinn-proto/src/tests/mod.rs b/quinn-proto/src/tests/mod.rs index 2c940caa6..50bb6ef0e 100644 --- a/quinn-proto/src/tests/mod.rs +++ b/quinn-proto/src/tests/mod.rs @@ -2994,6 +2994,32 @@ fn reject_manually() { )); } +#[test] +fn validate_then_reject_manually() { + let _guard = subscribe(); + let mut pair = Pair::default(); + pair.server.incoming_connection_behavior = IncomingConnectionBehavior::ValidateThenReject; + + // The server should now retry and reject incoming connections. + let client_ch = pair.begin_connect(client_config()); + pair.drive(); + pair.server.assert_no_accept(); + let client = pair.client.connections.get_mut(&client_ch).unwrap(); + assert!(client.is_closed()); + assert!(matches!( + client.poll(), + Some(Event::ConnectionLost { + reason: ConnectionError::ConnectionClosed(close) + }) if close.error_code == TransportErrorCode::CONNECTION_REFUSED + )); + pair.drive(); + assert_matches!(pair.client_conn_mut(client_ch).poll(), None); + assert_eq!(pair.client.known_connections(), 0); + assert_eq!(pair.client.known_cids(), 0); + assert_eq!(pair.server.known_connections(), 0); + assert_eq!(pair.server.known_cids(), 0); +} + #[test] fn endpoint_and_connection_impl_send_sync() { const fn is_send_sync() {} diff --git a/quinn-proto/src/tests/util.rs b/quinn-proto/src/tests/util.rs index 3780e7225..4b4625f7b 100644 --- a/quinn-proto/src/tests/util.rs +++ b/quinn-proto/src/tests/util.rs @@ -306,6 +306,7 @@ pub(super) enum IncomingConnectionBehavior { AcceptAll, RejectAll, Validate, + ValidateThenReject, Wait, } @@ -377,6 +378,13 @@ impl TestEndpoint { self.retry(incoming); } } + IncomingConnectionBehavior::ValidateThenReject => { + if incoming.remote_address_validated() { + self.reject(incoming); + } else { + self.retry(incoming); + } + } IncomingConnectionBehavior::Wait => { self.waiting_incoming.push(incoming); }