diff --git a/examples/http3_client.py b/examples/http3_client.py index fe55d22db..1132db55b 100644 --- a/examples/http3_client.py +++ b/examples/http3_client.py @@ -26,6 +26,8 @@ from aioquic.quic.events import QuicEvent from aioquic.quic.logger import QuicFileLogger from aioquic.tls import CipherSuite, SessionTicket +from aioquic.quic.congestion.reno import RenoCongestionControl +from aioquic.quic.congestion.cubic import CubicCongestionControl try: import uvloop @@ -512,6 +514,9 @@ async def main( parser.add_argument( "--zero-rtt", action="store_true", help="try to send requests using 0-RTT" ) + parser.add_argument( + "--congestion-control", type=str, help="which congestion control algorithm to use (reno, cubic)" + ) args = parser.parse_args() @@ -545,6 +550,13 @@ async def main( configuration.quic_logger = QuicFileLogger(args.quic_log) if args.secrets_log: configuration.secrets_log_file = open(args.secrets_log, "a") + if args.congestion_control: + if args.congestion_control == "cubic": + configuration.congestion_control = CubicCongestionControl() + elif args.congestion_control == "reno": + configuration.congestion_control = RenoCongestionControl() + else: + raise Exception("Invalid congestion control algorithm") if args.session_ticket: try: with open(args.session_ticket, "rb") as fp: diff --git a/examples/http3_server.py b/examples/http3_server.py index 522d25e67..a244fb5f4 100644 --- a/examples/http3_server.py +++ b/examples/http3_server.py @@ -25,6 +25,8 @@ from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent from aioquic.quic.logger import QuicFileLogger from aioquic.tls import SessionTicket +from aioquic.quic.congestion.reno import RenoCongestionControl +from aioquic.quic.congestion.cubic import CubicCongestionControl try: import uvloop @@ -556,6 +558,9 @@ async def main( parser.add_argument( "-v", "--verbose", action="store_true", help="increase logging verbosity" ) + parser.add_argument( + "--congestion-control", type=str, help="which congestion control algorithm to use (reno, cubic)" + ) args = parser.parse_args() logging.basicConfig( @@ -592,6 +597,14 @@ async def main( # load SSL certificate and key configuration.load_cert_chain(args.certificate, args.private_key) + if args.congestion_control: + if args.congestion_control == "cubic": + configuration.congestion_control = CubicCongestionControl() + elif args.congestion_control == "reno": + configuration.congestion_control = RenoCongestionControl() + else: + raise Exception("Invalid congestion control algorithm") + if uvloop is not None: uvloop.install() diff --git a/src/aioquic/quic/configuration.py b/src/aioquic/quic/configuration.py index d67d7fe2f..5f1f61a6b 100644 --- a/src/aioquic/quic/configuration.py +++ b/src/aioquic/quic/configuration.py @@ -2,6 +2,7 @@ from os import PathLike from re import split from typing import Any, List, Optional, TextIO, Union +from .recovery import QuicCongestionControl from ..tls import ( CipherSuite, @@ -82,6 +83,11 @@ class QuicConfiguration: The TLS session ticket which should be used for session resumption. """ + congestion_control: Optional[QuicCongestionControl] = None + """ + Selection for a congestion control algorithm + """ + cadata: Optional[bytes] = None cafile: Optional[str] = None capath: Optional[str] = None diff --git a/src/aioquic/quic/congestion/__init__.py b/src/aioquic/quic/congestion/__init__.py new file mode 100644 index 000000000..336082d70 --- /dev/null +++ b/src/aioquic/quic/congestion/__init__.py @@ -0,0 +1,145 @@ +from ..packet_builder import QuicSentPacket +from typing import Iterable, Optional, Dict, Any +from datetime import datetime +from enum import Enum + +K_GRANULARITY = 0.001 # seconds + +# congestion control +K_INITIAL_WINDOW = 10 +K_MINIMUM_WINDOW = 2 +K_LOSS_REDUCTION_FACTOR = 0.5 + +class CongestionEvent(Enum): + ACK=0 + PACKET_SENT=1 + PACKET_EXPIRED=2 + PACKET_LOST=3 + RTT_MEASURED=4 + +class QuicRttMonitor: + """ + Roundtrip time monitor for HyStart. + """ + + def __init__(self) -> None: + self._increases = 0 + self._last_time = None + self._ready = False + self._size = 5 + + self._filtered_min: Optional[float] = None + + self._sample_idx = 0 + self._sample_max: Optional[float] = None + self._sample_min: Optional[float] = None + self._sample_time = 0.0 + self._samples = [0.0 for i in range(self._size)] + + def add_rtt(self, rtt: float) -> None: + self._samples[self._sample_idx] = rtt + self._sample_idx += 1 + + if self._sample_idx >= self._size: + self._sample_idx = 0 + self._ready = True + + if self._ready: + self._sample_max = self._samples[0] + self._sample_min = self._samples[0] + for sample in self._samples[1:]: + if sample < self._sample_min: + self._sample_min = sample + elif sample > self._sample_max: + self._sample_max = sample + + def is_rtt_increasing(self, rtt: float, now: float) -> bool: + if now > self._sample_time + K_GRANULARITY: + self.add_rtt(rtt) + self._sample_time = now + + if self._ready: + if self._filtered_min is None or self._filtered_min > self._sample_max: + self._filtered_min = self._sample_max + + delta = self._sample_min - self._filtered_min + if delta * 4 >= self._filtered_min: + self._increases += 1 + if self._increases >= self._size: + return True + elif delta > 0: + self._increases = 0 + return False + +class QuicCongestionControl: + + def __init__(self, *, max_datagram_size : int, callback=None, fixed_cwnd = 10*1024*1024) -> None: + self.callback = callback # a callback argument that is called when an event occurs + # 10 GB window or custom fixed size window (shouldn't be used in real network !, use a real CCA instead) + self._max_datagram_size = max_datagram_size + self.cwnd = fixed_cwnd + self.data_in_flight = 0 + + def set_recovery(self, recovery): + # recovery is a QuicPacketRecovery instance + self.recovery = recovery + + def on_packet_acked(self, packet: QuicSentPacket, now : float): + if self.callback: + self.callback(CongestionEvent.ACK, self) + if type(self) == QuicCongestionControl: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_SENT, self) + if type(self) == QuicCongestionControl: + # don't call this if it is a superclass that runs + self.data_in_flight += packet.sent_bytes + + def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_EXPIRED, self) + if type(self) == QuicCongestionControl: + for packet in packets: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: + if self.callback: + self.callback(CongestionEvent.PACKET_LOST, self) + if type(self) == QuicCongestionControl: + for packet in packets: + # don't call this if it is a superclass that runs + self.data_in_flight -= packet.sent_bytes + + def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: + if self.callback: + self.callback(CongestionEvent.RTT_MEASURED, self) + + def get_congestion_window(self) -> int: + # return the cwnd in number of bytes + return self.cwnd + + def _set_congestion_window(self, value): + self.cwnd = value + + def get_ssthresh(self) -> Optional[int]: + pass + + def get_bytes_in_flight(self) -> int: + return self.data_in_flight + + def log_callback(self) -> Dict[str, Any]: + # a callback called when a recovery happens + # The data object will be saved in the log file, so feel free to add + # any attribute you want to track + data: Dict[str, Any] = { + "bytes_in_flight": self.get_bytes_in_flight(), + "cwnd": self.get_congestion_window(), + } + if self.get_ssthresh() is not None: + data["ssthresh"] = self.get_ssthresh() + + return data \ No newline at end of file diff --git a/src/aioquic/quic/congestion/cubic.py b/src/aioquic/quic/congestion/cubic.py new file mode 100644 index 000000000..b7758b1da --- /dev/null +++ b/src/aioquic/quic/congestion/cubic.py @@ -0,0 +1,222 @@ +from . import QuicCongestionControl, QuicRttMonitor, K_INITIAL_WINDOW, K_MINIMUM_WINDOW +from ..packet_builder import QuicSentPacket +from typing import Iterable, Optional, Dict, Any + +# cubic specific variables (see https://www.rfc-editor.org/rfc/rfc9438.html#name-definitions) +K_CUBIC_K = 1 +K_CUBIC_C = 0.4 +K_CUBIC_LOSS_REDUCTION_FACTOR = 0.7 +K_CUBIC_MAX_IDLE_TIME = 2 # reset the cwnd after 2 seconds of inactivity + +class CubicCongestionControl(QuicCongestionControl): + """ + Cubic congestion control implementation for aioquic + """ + + def __init__(self, max_datagram_size : int, callback=None, reno_friendly_activated = True) -> None: + super().__init__(max_datagram_size=max_datagram_size, callback=callback) + self.additive_increase_factor = max_datagram_size # increase by one segment + + self._max_datagram_size = max_datagram_size + self.bytes_in_flight = 0 + self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size + self._congestion_recovery_start_time = 0.0 + self.ssthresh: Optional[int] = None + self.reno_friendly_activated = reno_friendly_activated + + self._rtt_monitor = QuicRttMonitor() + + self.reset() + + self.last_ack = None + + self.now = 0 + + def better_cube_root(self, x): + if (x < 0): + # avoid precision errors that make the cube root returns an imaginary number + return -((-x)**(1./3.)) + else: + return (x)**(1./3.) + + def W_cubic(self, t): + W_max_segments = self._W_max / self._max_datagram_size + target_segments = K_CUBIC_C * (t - self.K)**3 + (W_max_segments) + return target_segments * self._max_datagram_size + + def is_reno_friendly(self, t) -> bool: + return self.reno_friendly_activated and self.W_cubic(t) < self._W_est + + def is_concave(self): + return self.congestion_window < self._W_max + + def is_convex(self): + return self.congestion_window >= self._W_max + + def reset(self): + self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size + + self._cwnd_prior = None + self._cwnd_epoch = None + self._t_epoch = None + self._W_max = None + self._first_slow_start = True + self._starting_congestion_avoidance = False + self.K = 0 + self._W_est = 0 + self._cwnd_epoch = 0 + self._t_epoch = 0 + self._W_max = self.congestion_window + + def on_packet_acked(self, packet: QuicSentPacket, now: float) -> None: + super().on_packet_acked(packet, now) + rtt = self.recovery._rtt_smoothed + self.bytes_in_flight -= packet.sent_bytes + self.last_ack = now + self.now = now + + if self.ssthresh is None or self.congestion_window < self.ssthresh: + # slow start + self.congestion_window += packet.sent_bytes + else: + # congestion avoidance + if (self._first_slow_start and not self._starting_congestion_avoidance): + # exiting slow start without having a loss + self._first_slow_start = False + self._cwnd_prior = self.congestion_window + self._W_max = self.congestion_window + self._t_epoch = now + self._cwnd_epoch = self.congestion_window + self._W_est = self._cwnd_epoch + # calculate K + W_max_segments = self._W_max / self._max_datagram_size + cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size + self.K = self.better_cube_root((W_max_segments - cwnd_epoch_segments)/K_CUBIC_C) + + # initialize the variables used at start of congestion avoidance + if self._starting_congestion_avoidance: + self._starting_congestion_avoidance = False + self._first_slow_start = False + self._t_epoch = now + self._cwnd_epoch = self.congestion_window + self._W_est = self._cwnd_epoch + # calculate K + W_max_segments = self._W_max / self._max_datagram_size + cwnd_epoch_segments = self._cwnd_epoch / self._max_datagram_size + self.K = self.better_cube_root((W_max_segments - cwnd_epoch_segments)/K_CUBIC_C) + + + self._W_est = self._W_est + self.additive_increase_factor*(packet.sent_bytes/self.congestion_window) + + t = now - self._t_epoch + + target = None + if (self.W_cubic(t + rtt) < self.congestion_window): + target = self.congestion_window + elif (self.W_cubic(t + rtt) > 1.5*self.congestion_window): + target = self.congestion_window*1.5 + else: + target = self.W_cubic(t + rtt) + + + if self.is_reno_friendly(t): + # reno friendly region of cubic (https://www.rfc-editor.org/rfc/rfc9438.html#name-reno-friendly-region) + self.congestion_window = self._W_est + elif self.is_concave(): + # concave region of cubic (https://www.rfc-editor.org/rfc/rfc9438.html#name-concave-region) + self.congestion_window = self.congestion_window + ((target - self.congestion_window)*(self._max_datagram_size/self.congestion_window)) + else: + # convex region of cubic (https://www.rfc-editor.org/rfc/rfc9438.html#name-convex-region) + self.congestion_window = self.congestion_window + ((target - self.congestion_window)*(self._max_datagram_size/self.congestion_window)) + + def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None: + super().on_packet_sent(packet, now) + self.bytes_in_flight += packet.sent_bytes + if self.last_ack == None: + return + elapsed_idle = now - self.last_ack + if (elapsed_idle >= K_CUBIC_MAX_IDLE_TIME): + self.reset() + + def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: + super().on_packets_expired(packets) + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + + def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: + self.now = now + super().on_packets_lost(packets, now) + lost_largest_time = 0.0 + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + lost_largest_time = packet.sent_time + + # start a new congestion event if packet was sent after the + # start of the previous congestion recovery period. + if lost_largest_time > self._congestion_recovery_start_time: + + self._congestion_recovery_start_time = now + + # Normal congestion handle, can't be used in same time as fast convergence + # self._W_max = self.congestion_window + + # fast convergence + if (self._W_max != None and self.congestion_window < self._W_max): + self._W_max = self.congestion_window * (1 + K_CUBIC_LOSS_REDUCTION_FACTOR) / 2 + else: + self._W_max = self.congestion_window + + # normal congestion MD + flight_size = self.bytes_in_flight + new_ssthresh = max(int(flight_size*K_CUBIC_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW * self._max_datagram_size) + self.ssthresh = new_ssthresh + self._cwnd_prior = self.congestion_window + self.congestion_window = max(self.ssthresh, K_MINIMUM_WINDOW * self._max_datagram_size) + + + self._starting_congestion_avoidance = True # restart a new congestion avoidance phase + + + def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: + self.now = now + super().on_rtt_measurement(latest_rtt, now) + # check whether we should exit slow start + if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( + latest_rtt, now + ): + self._cwnd_prior = self.congestion_window + + + def get_congestion_window(self) -> int: + return int(self.congestion_window) + + def _set_congestion_window(self, value): + self.congestion_window = value + + def get_ssthresh(self) -> int: + return self.ssthresh + + def get_bytes_in_flight(self) -> int: + return self.bytes_in_flight + + def log_callback(self) -> Dict[str, Any]: + data = super().log_callback() + + if self._W_max == None: + data["W_max"] = None + else: + data["W_max"] = int(self._W_max) + + + if self.ssthresh != None: + t = self.now - self._t_epoch + + # saving the phase + if (self.ssthresh == None): + data["Phase"] = "slow-start" + elif (self.is_reno_friendly(t)): + data["Phase"] = "reno-friendly region" + else: + data["Phase"] = "cubic-growth" + + return data \ No newline at end of file diff --git a/src/aioquic/quic/congestion/reno.py b/src/aioquic/quic/congestion/reno.py new file mode 100644 index 000000000..ab4d49ffc --- /dev/null +++ b/src/aioquic/quic/congestion/reno.py @@ -0,0 +1,102 @@ +from . import QuicCongestionControl, QuicRttMonitor, K_LOSS_REDUCTION_FACTOR, K_INITIAL_WINDOW, K_MINIMUM_WINDOW +from ..packet_builder import QuicSentPacket +from typing import Iterable, Dict, Any, Optional + + +class RenoCongestionControl(QuicCongestionControl): + """ + New Reno congestion control. + """ + + def __init__(self, * , max_datagram_size: int, callback=None) -> None: + super().__init__(max_datagram_size=max_datagram_size, callback=callback) + self._max_datagram_size = max_datagram_size + self.bytes_in_flight = 0 + self.congestion_window = K_INITIAL_WINDOW * max_datagram_size + self._congestion_recovery_start_time = 0.0 + self._congestion_stash = 0 + self.ssthresh: Optional[int] = None + self._rtt_monitor = QuicRttMonitor() + + + def on_packet_acked(self, packet: QuicSentPacket, now : float) -> None: + super().on_packet_acked(packet, now) + self.bytes_in_flight -= packet.sent_bytes + + # don't increase window in congestion recovery + if packet.sent_time <= self._congestion_recovery_start_time: + return + + if self.ssthresh is None or self.congestion_window < self.ssthresh: + # slow start + self.congestion_window += packet.sent_bytes + else: + # congestion avoidance + self._congestion_stash += packet.sent_bytes + count = self._congestion_stash // self.congestion_window + if count: + self._congestion_stash -= count * self.congestion_window + self.congestion_window += count * self._max_datagram_size + + def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None: + super().on_packet_sent(packet, now) + self.bytes_in_flight += packet.sent_bytes + + def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: + super().on_packets_expired(packets) + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + + def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: + super().on_packets_lost(packets, now) + lost_largest_time = 0.0 + for packet in packets: + self.bytes_in_flight -= packet.sent_bytes + lost_largest_time = packet.sent_time + + if lost_largest_time > self._congestion_recovery_start_time: + self._congestion_recovery_start_time = now + self.congestion_window = max( + int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), + K_MINIMUM_WINDOW * self._max_datagram_size, + ) + self.ssthresh = self.congestion_window + + # start a new congestion event if packet was sent after the + # start of the previous congestion recovery period. + if lost_largest_time > self._congestion_recovery_start_time: + self._congestion_recovery_start_time = now + self.congestion_window = max( + int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), K_MINIMUM_WINDOW * self._max_datagram_size + ) + self.ssthresn = self.congestion_window + + # TODO : collapse congestion window if persistent congestion + + def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: + super().on_rtt_measurement(latest_rtt, now) + # check whether we should exit slow start + if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( + latest_rtt, now + ): + self.ssthresh = self.congestion_window + + def get_congestion_window(self) -> int: + return int(self.congestion_window) + + def _set_congestion_window(self, value): + self.congestion_window = value + + def get_ssthresh(self) -> int: + return self.ssthresh + + def get_bytes_in_flight(self) -> int: + return self.bytes_in_flight + + def log_callback(self) -> Dict[str, Any]: + data = super().log_callback() + if (self.ssthresh == None): + data["Phase"] = "slow-start" + else: + data["Phase"] = "congestion-avoidance" + return data \ No newline at end of file diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index 81d1bdb14..bb615e649 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -363,15 +363,28 @@ def __init__( odcid=self._original_destination_connection_id, ) + # loss recovery - self._loss = QuicPacketRecovery( - initial_rtt=configuration.initial_rtt, - max_datagram_size=self._max_datagram_size, - peer_completed_address_validation=not self._is_client, - quic_logger=self._quic_logger, - send_probe=self._send_probe, - logger=self._logger, - ) + if (configuration.congestion_control): + self._loss = QuicPacketRecovery( + initial_rtt=configuration.initial_rtt, + max_datagram_size=self._max_datagram_size, + peer_completed_address_validation=not self._is_client, + quic_logger=self._quic_logger, + send_probe=self._send_probe, + logger=self._logger, + congestion_control=configuration.congestion_control + ) + else: + + self._loss = QuicPacketRecovery( + initial_rtt=configuration.initial_rtt, + max_datagram_size=self._max_datagram_size, + peer_completed_address_validation=not self._is_client, + quic_logger=self._quic_logger, + send_probe=self._send_probe, + logger=self._logger, + ) # things to send self._close_pending = False @@ -577,7 +590,7 @@ def datagrams_to_send(self, now: float) -> List[Tuple[bytes, NetworkAddress]]: for packet in packets: packet.sent_time = now self._loss.on_packet_sent( - packet=packet, space=self._spaces[packet.epoch] + packet=packet, space=self._spaces[packet.epoch], now=now ) if packet.epoch == tls.Epoch.HANDSHAKE: sent_handshake = True @@ -3240,4 +3253,4 @@ def _write_streams_blocked_frame( is_unidirectional=frame_type == QuicFrameType.STREAMS_BLOCKED_UNI, limit=limit, ) - ) + ) \ No newline at end of file diff --git a/src/aioquic/quic/logger.py b/src/aioquic/quic/logger.py index 8deede679..429e0b319 100644 --- a/src/aioquic/quic/logger.py +++ b/src/aioquic/quic/logger.py @@ -32,7 +32,6 @@ def hexdump(data: bytes) -> str: return binascii.hexlify(data).decode("ascii") - class QuicLoggerTrace: """ A QUIC event trace. diff --git a/src/aioquic/quic/recovery.py b/src/aioquic/quic/recovery.py index 2a95c8590..d339dada3 100644 --- a/src/aioquic/quic/recovery.py +++ b/src/aioquic/quic/recovery.py @@ -1,10 +1,14 @@ import logging import math from typing import Any, Callable, Dict, Iterable, List, Optional +from datetime import datetime from .logger import QuicLoggerTrace from .packet_builder import QuicDeliveryState, QuicSentPacket from .rangeset import RangeSet +from .congestion import QuicCongestionControl +from .congestion.cubic import CubicCongestionControl +from .congestion.reno import RenoCongestionControl # loss detection K_PACKET_THRESHOLD = 3 @@ -12,11 +16,9 @@ K_TIME_THRESHOLD = 9 / 8 K_MICRO_SECOND = 0.000001 K_SECOND = 1.0 +K_MIN_RTT = 0.001 # 1ms + -# congestion control -K_INITIAL_WINDOW = 10 -K_MINIMUM_WINDOW = 2 -K_LOSS_REDUCTION_FACTOR = 0.5 class QuicPacketSpace: @@ -42,6 +44,7 @@ def __init__(self, *, max_datagram_size: int) -> None: self.bucket_time: float = 0.0 self.evaluation_time: float = 0.0 self.packet_time: Optional[float] = None + self.pacing_rate = None def next_send_time(self, now: float) -> float: if self.packet_time is not None: @@ -66,7 +69,10 @@ def update_bucket(self, now: float) -> None: self.evaluation_time = now def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: - pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND) + if self.pacing_rate == None or self.pacing_rate == 0: + pacing_rate = congestion_window / max(smoothed_rtt, K_MICRO_SECOND) + else: + pacing_rate = self.pacing_rate self.packet_time = max( K_MICRO_SECOND, min(self._max_datagram_size / pacing_rate, K_SECOND) ) @@ -81,70 +87,8 @@ def update_rate(self, congestion_window: int, smoothed_rtt: float) -> None: if self.bucket_time > self.bucket_max: self.bucket_time = self.bucket_max - -class QuicCongestionControl: - """ - New Reno congestion control. - """ - - def __init__(self, *, max_datagram_size: int) -> None: - self._max_datagram_size = max_datagram_size - self.bytes_in_flight = 0 - self.congestion_window = K_INITIAL_WINDOW * self._max_datagram_size - self._congestion_recovery_start_time = 0.0 - self._congestion_stash = 0 - self._rtt_monitor = QuicRttMonitor() - self.ssthresh: Optional[int] = None - - def on_packet_acked(self, packet: QuicSentPacket) -> None: - self.bytes_in_flight -= packet.sent_bytes - - # don't increase window in congestion recovery - if packet.sent_time <= self._congestion_recovery_start_time: - return - - if self.ssthresh is None or self.congestion_window < self.ssthresh: - # slow start - self.congestion_window += packet.sent_bytes - else: - # congestion avoidance - self._congestion_stash += packet.sent_bytes - count = self._congestion_stash // self.congestion_window - if count: - self._congestion_stash -= count * self.congestion_window - self.congestion_window += count * self._max_datagram_size - - def on_packet_sent(self, packet: QuicSentPacket) -> None: - self.bytes_in_flight += packet.sent_bytes - - def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None: - for packet in packets: - self.bytes_in_flight -= packet.sent_bytes - - def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None: - lost_largest_time = 0.0 - for packet in packets: - self.bytes_in_flight -= packet.sent_bytes - lost_largest_time = packet.sent_time - - # start a new congestion event if packet was sent after the - # start of the previous congestion recovery period. - if lost_largest_time > self._congestion_recovery_start_time: - self._congestion_recovery_start_time = now - self.congestion_window = max( - int(self.congestion_window * K_LOSS_REDUCTION_FACTOR), - K_MINIMUM_WINDOW * self._max_datagram_size, - ) - self.ssthresh = self.congestion_window - - # TODO : collapse congestion window if persistent congestion - - def on_rtt_measurement(self, latest_rtt: float, now: float) -> None: - # check whether we should exit slow start - if self.ssthresh is None and self._rtt_monitor.is_rtt_increasing( - latest_rtt, now - ): - self.ssthresh = self.congestion_window + def set_pacing_rate(self, rate): + self.pacing_rate = rate class QuicPacketRecovery: @@ -160,6 +104,7 @@ def __init__( send_probe: Callable[[], None], logger: Optional[logging.LoggerAdapter] = None, quic_logger: Optional[QuicLoggerTrace] = None, + congestion_control: QuicCongestionControl = RenoCongestionControl(max_datagram_size=1280) ) -> None: self.max_ack_delay = 0.025 self.peer_completed_address_validation = peer_completed_address_validation @@ -181,16 +126,20 @@ def __init__( self._time_of_last_sent_ack_eliciting_packet = 0.0 # congestion control - self._cc = QuicCongestionControl(max_datagram_size=max_datagram_size) + + self._cc = congestion_control + self._cc._max_datagram_size = max_datagram_size + self._cc.set_recovery(self) + self._pacer = QuicPacketPacer(max_datagram_size=max_datagram_size) @property def bytes_in_flight(self) -> int: - return self._cc.bytes_in_flight + return self._cc.get_bytes_in_flight() @property def congestion_window(self) -> int: - return self._cc.congestion_window + return self._cc.get_congestion_window() def discard_space(self, space: QuicPacketSpace) -> None: assert space in self.spaces @@ -263,7 +212,7 @@ def on_ack_received( is_ack_eliciting = True space.ack_eliciting_in_flight -= 1 if packet.in_flight: - self._cc.on_packet_acked(packet) + self._cc.on_packet_acked(packet, now=now) largest_newly_acked = packet_number largest_sent_time = packet.sent_time @@ -304,7 +253,7 @@ def on_ack_received( # inform congestion controller self._cc.on_rtt_measurement(latest_rtt, now=now) self._pacer.update_rate( - congestion_window=self._cc.congestion_window, + congestion_window=self._cc.get_congestion_window(), smoothed_rtt=self._rtt_smoothed, ) @@ -327,7 +276,7 @@ def on_loss_detection_timeout(self, now: float) -> None: self._pto_count += 1 self.reschedule_data(now=now) - def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None: + def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace, now : float) -> None: space.sent_packets[packet.packet_number] = packet if packet.is_ack_eliciting: @@ -337,7 +286,7 @@ def on_packet_sent(self, packet: QuicSentPacket, space: QuicPacketSpace) -> None self._time_of_last_sent_ack_eliciting_packet = packet.sent_time # add packet to bytes in flight - self._cc.on_packet_sent(packet) + self._cc.on_packet_sent(packet, now=now) if self._quic_logger is not None: self._log_metrics_updated() @@ -398,12 +347,7 @@ def _get_loss_space(self) -> Optional[QuicPacketSpace]: return loss_space def _log_metrics_updated(self, log_rtt=False) -> None: - data: Dict[str, Any] = { - "bytes_in_flight": self._cc.bytes_in_flight, - "cwnd": self._cc.congestion_window, - } - if self._cc.ssthresh is not None: - data["ssthresh"] = self._cc.ssthresh + data: Dict[str, Any] = self._cc.log_callback() if log_rtt: data.update( @@ -451,63 +395,8 @@ def _on_packets_lost( if lost_packets_cc: self._cc.on_packets_lost(lost_packets_cc, now=now) self._pacer.update_rate( - congestion_window=self._cc.congestion_window, + congestion_window=self._cc.get_congestion_window(), smoothed_rtt=self._rtt_smoothed, ) if self._quic_logger is not None: - self._log_metrics_updated() - - -class QuicRttMonitor: - """ - Roundtrip time monitor for HyStart. - """ - - def __init__(self) -> None: - self._increases = 0 - self._last_time = None - self._ready = False - self._size = 5 - - self._filtered_min: Optional[float] = None - - self._sample_idx = 0 - self._sample_max: Optional[float] = None - self._sample_min: Optional[float] = None - self._sample_time = 0.0 - self._samples = [0.0 for i in range(self._size)] - - def add_rtt(self, rtt: float) -> None: - self._samples[self._sample_idx] = rtt - self._sample_idx += 1 - - if self._sample_idx >= self._size: - self._sample_idx = 0 - self._ready = True - - if self._ready: - self._sample_max = self._samples[0] - self._sample_min = self._samples[0] - for sample in self._samples[1:]: - if sample < self._sample_min: - self._sample_min = sample - elif sample > self._sample_max: - self._sample_max = sample - - def is_rtt_increasing(self, rtt: float, now: float) -> bool: - if now > self._sample_time + K_GRANULARITY: - self.add_rtt(rtt) - self._sample_time = now - - if self._ready: - if self._filtered_min is None or self._filtered_min > self._sample_max: - self._filtered_min = self._sample_max - - delta = self._sample_min - self._filtered_min - if delta * 4 >= self._filtered_min: - self._increases += 1 - if self._increases >= self._size: - return True - elif delta > 0: - self._increases = 0 - return False + self._log_metrics_updated() \ No newline at end of file diff --git a/tests/test_cubic.py b/tests/test_cubic.py new file mode 100644 index 000000000..39750c8a5 --- /dev/null +++ b/tests/test_cubic.py @@ -0,0 +1,58 @@ +from aioquic.quic.congestion.cubic import CubicCongestionControl, K_CUBIC_C, K_CUBIC_LOSS_REDUCTION_FACTOR, QuicSentPacket +from aioquic.quic.congestion import K_MAX_DATAGRAM_SIZE +import unittest + +def W_cubic(t, K, W_max): + return K_CUBIC_C * (t - K)**3 + (W_max) + +def cube_root(x): + if (x < 0): return -((-x)**(1/3)) + else: return x**(1/3) + +class CubicTests(unittest.TestCase): + + def test_congestion_avoidance(self): + """ + Check if the cubic implementation respects the mathematical formula defined in the rfc 9438 + """ + + n = 400 # number of ms to check + + W_max = 5 # starting W_max + K = cube_root(W_max*(1-K_CUBIC_LOSS_REDUCTION_FACTOR)/K_CUBIC_C) + cwnd = W_max*K_CUBIC_LOSS_REDUCTION_FACTOR + + correct = [] + + test_range = range(n) + + for i in test_range: + correct.append(W_cubic(i/1000, K, W_max) * K_MAX_DATAGRAM_SIZE) + + cubic = CubicCongestionControl() + cubic._W_max = W_max * K_MAX_DATAGRAM_SIZE + cubic._starting_congestion_avoidance = True + cubic.congestion_window = cwnd * K_MAX_DATAGRAM_SIZE + cubic.slow_start.ssthresh = cubic.congestion_window + cubic._W_est = 0 + + results = [] + for i in test_range: + cwnd = cubic.congestion_window // K_MAX_DATAGRAM_SIZE # number of segments + + # simulate the reception of cwnd packets (a full window of acks) + for _ in range(int(cwnd)): + packet = QuicSentPacket(None, True, True, True, 0, 0) + packet.sent_bytes = 0 # won't affect results + rtt = 0 + cubic.on_packet_acked_timed(packet, i/1000, rtt) + + results.append(cubic.congestion_window) + + for i in test_range: + # check if it is almost equal to the value of W_cubic + self.assertTrue(correct[i]*0.99 <= results[i] <= 1.01*correct[i], F"Error at {i}ms, Result={results[i]}, Expected={correct[i]}") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_recovery.py b/tests/test_recovery.py index 086016e59..3c67ba79b 100644 --- a/tests/test_recovery.py +++ b/tests/test_recovery.py @@ -8,9 +8,9 @@ from aioquic.quic.recovery import ( QuicPacketPacer, QuicPacketRecovery, - QuicPacketSpace, - QuicRttMonitor, + QuicPacketSpace ) +from aioquic.quic.congestion import QuicRttMonitor def send_probe():