diff --git a/src/connection/connection.rs b/src/connection/connection.rs index 10459c4d..cc472335 100644 --- a/src/connection/connection.rs +++ b/src/connection/connection.rs @@ -4746,6 +4746,36 @@ pub(crate) mod tests { Ok(()) } + #[test] + fn handshake_with_antiamplification_deadlock() -> Result<()> { + let mut test_pair = TestPair::new_with_test_config()?; + + // Client send Initial. + let packets = TestPair::conn_packets_out(&mut test_pair.client)?; + TestPair::conn_packets_in(&mut test_pair.server, packets)?; + + // Server send Initial and Handshake. + let mut packets = TestPair::conn_packets_out(&mut test_pair.server)?; + + // Fake dropping the second packet. + packets.truncate(1); + + // Client recv Initial and the first Handshake. + TestPair::conn_packets_in(&mut test_pair.client, packets)?; + assert!(!test_pair.client.tls_session.is_completed()); + + // Client send ACK and PADDING and wait for retransmission of the second packet. + let _ = TestPair::conn_packets_out(&mut test_pair.client)?; + + // `LossDetection` timer should not be None to avoid deadlock. + assert!(test_pair.client.timeout().is_some()); + assert!(test_pair.client.timers.get(Timer::LossDetection).is_some()); + + // TODO: complete the remaining part after supporting anti-amplification in server side. + + Ok(()) + } + #[test] fn handshake_with_alpn_mismatched() -> Result<()> { let mut client_config = TestPair::new_test_config(false)?; diff --git a/src/connection/recovery.rs b/src/connection/recovery.rs index 20216604..6128f486 100644 --- a/src/connection/recovery.rs +++ b/src/connection/recovery.rs @@ -82,6 +82,9 @@ pub struct Recovery { /// declared lost. The size does not include IP or UDP overhead. pub bytes_in_flight: usize, + /// Number of ack-eliciting packets in flight. + pub ack_eliciting_in_flight: u64, + /// RTT estimation for the corresponding path. pub rtt: RttEstimator, @@ -104,6 +107,7 @@ impl Recovery { pkt_thresh: INITIAL_PACKET_THRESHOLD, time_thresh: INITIAL_TIME_THRESHOLD, bytes_in_flight: 0, + ack_eliciting_in_flight: 0, rtt: RttEstimator::new(conf.initial_rtt), congestion: congestion_control::build_congestion_controller(conf), trace_id: String::from(""), @@ -167,6 +171,8 @@ impl Recovery { if ack_eliciting { space.time_of_last_sent_ack_eliciting_pkt = Some(now); space.loss_probes = space.loss_probes.saturating_sub(1); + space.ack_eliciting_in_flight += 1; + self.ack_eliciting_in_flight += 1; } space.bytes_in_flight += sent_size; @@ -295,6 +301,13 @@ impl Recovery { space.bytes_in_flight = space.bytes_in_flight.saturating_sub(sent_pkt.sent_size); self.bytes_in_flight = self.bytes_in_flight.saturating_sub(sent_pkt.sent_size); + + if sent_pkt.ack_eliciting { + space.ack_eliciting_in_flight = + space.ack_eliciting_in_flight.saturating_sub(1); + self.ack_eliciting_in_flight = + self.ack_eliciting_in_flight.saturating_sub(1); + } } // Process each acked packet in congestion controller and update delivery @@ -397,6 +410,13 @@ impl Recovery { lost_bytes += unacked.sent_size; space.bytes_in_flight = space.bytes_in_flight.saturating_sub(unacked.sent_size); self.bytes_in_flight = self.bytes_in_flight.saturating_sub(unacked.sent_size); + + if unacked.ack_eliciting { + space.ack_eliciting_in_flight = + space.ack_eliciting_in_flight.saturating_sub(1); + self.ack_eliciting_in_flight = + self.ack_eliciting_in_flight.saturating_sub(1); + } } latest_lost_packet = Some(unacked.clone()); trace!( @@ -484,7 +504,7 @@ impl Recovery { // TODO: The server's timer is not set if nothing can be sent. - if self.bytes_in_flight == 0 && handshake_status.peer_verified_address { + if self.ack_eliciting_in_flight == 0 && handshake_status.peer_verified_address { // There is nothing to detect lost, so no timer is set. // However, the client needs to arm the timer if the // server might be blocked by the anti-amplification limit. @@ -530,7 +550,7 @@ impl Recovery { } // PTO timer mode (REVISIT) - let sid = if self.bytes_in_flight > 0 { + let sid = if self.ack_eliciting_in_flight > 0 { // Send new data if available, else retransmit old data. If neither // is available, send a single PING frame. let (_, e) = self.get_pto_time_and_space(space_id, spaces, handshake_status, now); @@ -642,8 +662,8 @@ impl Recovery { ) -> (Option, SpaceId) { let mut duration = self.calculate_pto(); - // Arm PTO from now when there are no inflight packets. - if self.bytes_in_flight == 0 { + // Arm PTO from now when there are no ack-eliciting packets inflight. + if self.ack_eliciting_in_flight == 0 { if handshake_status.derived_handshake_keys { return (Some(now + duration), SpaceId::Handshake); } else { @@ -665,7 +685,7 @@ impl Recovery { Some(space) => space, None => continue, }; - if space.bytes_in_flight == 0 { + if space.ack_eliciting_in_flight == 0 { continue; } @@ -720,6 +740,7 @@ impl Recovery { space.loss_time = None; space.loss_probes = 0; space.bytes_in_flight = 0; + space.ack_eliciting_in_flight = 0; self.set_loss_detection_timer(space_id, spaces, handshake_status, now); } @@ -728,12 +749,16 @@ impl Recovery { /// When Initial or Handshake keys are discarded, packets sent in that /// space no longer count toward bytes in flight. fn remove_from_bytes_in_flight(&mut self, space: &PacketNumSpace) { - let unacked_bytes = space - .sent - .iter() - .filter(|p| p.in_flight && p.time_acked.is_none() && p.time_lost.is_none()) - .fold(0, |acc, p| acc + p.sent_size); - self.bytes_in_flight = self.bytes_in_flight.saturating_sub(unacked_bytes); + for pkt in &space.sent { + if !pkt.in_flight || pkt.time_acked.is_some() || pkt.time_lost.is_some() { + continue; + } + + self.bytes_in_flight = self.bytes_in_flight.saturating_sub(pkt.sent_size); + if pkt.ack_eliciting { + self.ack_eliciting_in_flight = self.ack_eliciting_in_flight.saturating_sub(1); + } + } } /// Update maximum datagram size @@ -830,7 +855,9 @@ mod tests { recovery.on_packet_sent(sent_pkt2, space_id, &mut spaces, status, now); assert_eq!(spaces.get(space_id).unwrap().sent.len(), 3); assert_eq!(spaces.get(space_id).unwrap().bytes_in_flight, 3003); + assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 3); assert_eq!(recovery.bytes_in_flight, 3003); + assert_eq!(recovery.ack_eliciting_in_flight, 3); // Advance ticks and fake receiving of ack now += Duration::from_millis(100); @@ -839,6 +866,8 @@ mod tests { acked.insert(2..3); recovery.on_ack_received(&acked, 0, SpaceId::Handshake, &mut spaces, status, now)?; assert_eq!(spaces.get(space_id).unwrap().sent.len(), 2); + assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 1); + assert_eq!(recovery.ack_eliciting_in_flight, 1); // Advance ticks until loss timeout now = recovery.loss_detection_timer().unwrap(); @@ -846,6 +875,8 @@ mod tests { recovery.on_loss_detection_timeout(SpaceId::Handshake, &mut spaces, status, now); assert_eq!(lost_pkts, 1); assert_eq!(lost_bytes, 1001); + assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 0); + assert_eq!(recovery.ack_eliciting_in_flight, 0); Ok(()) } @@ -1026,7 +1057,9 @@ mod tests { recovery.on_pkt_num_space_discarded(space_id, &mut spaces, status, now); assert_eq!(spaces.get(space_id).unwrap().sent.len(), 0); assert_eq!(spaces.get(space_id).unwrap().bytes_in_flight, 0); + assert_eq!(spaces.get(space_id).unwrap().ack_eliciting_in_flight, 0); assert_eq!(recovery.bytes_in_flight, 1003); + assert_eq!(recovery.ack_eliciting_in_flight, 1); Ok(()) } diff --git a/src/connection/space.rs b/src/connection/space.rs index c998c499..dd6211d6 100644 --- a/src/connection/space.rs +++ b/src/connection/space.rs @@ -118,6 +118,9 @@ pub struct PacketNumSpace { /// number space. pub bytes_in_flight: usize, + /// Number of ack-eliciting packets in flight. + pub ack_eliciting_in_flight: u64, + /// Packet number space for application data pub is_data: bool, @@ -146,6 +149,7 @@ impl PacketNumSpace { largest_acked_pkt: std::u64::MAX, loss_probes: 0, bytes_in_flight: 0, + ack_eliciting_in_flight: 0, is_data: id != SpaceId::Initial && id != SpaceId::Handshake, reinject: ReinjectQueue::default(), }