Skip to content

Commit

Permalink
Added cubic congestion control
Browse files Browse the repository at this point in the history
  • Loading branch information
Aperence committed Dec 12, 2023
1 parent 2ae1ad4 commit 4cf275a
Show file tree
Hide file tree
Showing 11 changed files with 611 additions and 152 deletions.
12 changes: 12 additions & 0 deletions examples/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from aioquic.quic.events import QuicEvent
from aioquic.quic.logger import QuicFileLogger
from aioquic.tls import CipherSuite, SessionTicket
from aioquic.quic.congestion.reno import RenoCongestionControl
from aioquic.quic.congestion.cubic import CubicCongestionControl

try:
import uvloop
Expand Down Expand Up @@ -512,6 +514,9 @@ async def main(
parser.add_argument(
"--zero-rtt", action="store_true", help="try to send requests using 0-RTT"
)
parser.add_argument(
"--congestion-control", type=str, help="which congestion control algorithm to use (reno, cubic)"
)

args = parser.parse_args()

Expand Down Expand Up @@ -545,6 +550,13 @@ async def main(
configuration.quic_logger = QuicFileLogger(args.quic_log)
if args.secrets_log:
configuration.secrets_log_file = open(args.secrets_log, "a")
if args.congestion_control:
if args.congestion_control == "cubic":
configuration.congestion_control = CubicCongestionControl()
elif args.congestion_control == "reno":
configuration.congestion_control = RenoCongestionControl()
else:
raise Exception("Invalid congestion control algorithm")
if args.session_ticket:
try:
with open(args.session_ticket, "rb") as fp:
Expand Down
13 changes: 13 additions & 0 deletions examples/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from aioquic.quic.events import DatagramFrameReceived, ProtocolNegotiated, QuicEvent
from aioquic.quic.logger import QuicFileLogger
from aioquic.tls import SessionTicket
from aioquic.quic.congestion.reno import RenoCongestionControl
from aioquic.quic.congestion.cubic import CubicCongestionControl

try:
import uvloop
Expand Down Expand Up @@ -556,6 +558,9 @@ async def main(
parser.add_argument(
"-v", "--verbose", action="store_true", help="increase logging verbosity"
)
parser.add_argument(
"--congestion-control", type=str, help="which congestion control algorithm to use (reno, cubic)"
)
args = parser.parse_args()

logging.basicConfig(
Expand Down Expand Up @@ -592,6 +597,14 @@ async def main(
# load SSL certificate and key
configuration.load_cert_chain(args.certificate, args.private_key)

if args.congestion_control:
if args.congestion_control == "cubic":
configuration.congestion_control = CubicCongestionControl()
elif args.congestion_control == "reno":
configuration.congestion_control = RenoCongestionControl()
else:
raise Exception("Invalid congestion control algorithm")

if uvloop is not None:
uvloop.install()

Expand Down
6 changes: 6 additions & 0 deletions src/aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import PathLike
from re import split
from typing import Any, List, Optional, TextIO, Union
from .recovery import QuicCongestionControl

from ..tls import (
CipherSuite,
Expand Down Expand Up @@ -82,6 +83,11 @@ class QuicConfiguration:
The TLS session ticket which should be used for session resumption.
"""

congestion_control: Optional[QuicCongestionControl] = None
"""
Selection for a congestion control algorithm
"""

cadata: Optional[bytes] = None
cafile: Optional[str] = None
capath: Optional[str] = None
Expand Down
145 changes: 145 additions & 0 deletions src/aioquic/quic/congestion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from ..packet_builder import QuicSentPacket
from typing import Iterable, Optional, Dict, Any
from datetime import datetime
from enum import Enum

K_GRANULARITY = 0.001 # seconds

# congestion control
K_INITIAL_WINDOW = 10
K_MINIMUM_WINDOW = 2
K_LOSS_REDUCTION_FACTOR = 0.5

class CongestionEvent(Enum):
ACK=0
PACKET_SENT=1
PACKET_EXPIRED=2
PACKET_LOST=3
RTT_MEASURED=4

class QuicRttMonitor:
"""
Roundtrip time monitor for HyStart.
"""

def __init__(self) -> None:
self._increases = 0
self._last_time = None
self._ready = False
self._size = 5

self._filtered_min: Optional[float] = None

self._sample_idx = 0
self._sample_max: Optional[float] = None
self._sample_min: Optional[float] = None
self._sample_time = 0.0
self._samples = [0.0 for i in range(self._size)]

def add_rtt(self, rtt: float) -> None:
self._samples[self._sample_idx] = rtt
self._sample_idx += 1

if self._sample_idx >= self._size:
self._sample_idx = 0
self._ready = True

if self._ready:
self._sample_max = self._samples[0]
self._sample_min = self._samples[0]
for sample in self._samples[1:]:
if sample < self._sample_min:
self._sample_min = sample
elif sample > self._sample_max:
self._sample_max = sample

def is_rtt_increasing(self, rtt: float, now: float) -> bool:
if now > self._sample_time + K_GRANULARITY:
self.add_rtt(rtt)
self._sample_time = now

if self._ready:
if self._filtered_min is None or self._filtered_min > self._sample_max:
self._filtered_min = self._sample_max

delta = self._sample_min - self._filtered_min
if delta * 4 >= self._filtered_min:
self._increases += 1
if self._increases >= self._size:
return True
elif delta > 0:
self._increases = 0
return False

class QuicCongestionControl:

def __init__(self, *, max_datagram_size : int, callback=None, fixed_cwnd = 10*1024*1024) -> None:
self.callback = callback # a callback argument that is called when an event occurs
# 10 GB window or custom fixed size window (shouldn't be used in real network !, use a real CCA instead)
self._max_datagram_size = max_datagram_size
self.cwnd = fixed_cwnd
self.data_in_flight = 0

def set_recovery(self, recovery):
# recovery is a QuicPacketRecovery instance
self.recovery = recovery

def on_packet_acked(self, packet: QuicSentPacket, now : float):
if self.callback:
self.callback(CongestionEvent.ACK, self)
if type(self) == QuicCongestionControl:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_packet_sent(self, packet: QuicSentPacket, now : float) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_SENT, self)
if type(self) == QuicCongestionControl:
# don't call this if it is a superclass that runs
self.data_in_flight += packet.sent_bytes

def on_packets_expired(self, packets: Iterable[QuicSentPacket]) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_EXPIRED, self)
if type(self) == QuicCongestionControl:
for packet in packets:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_packets_lost(self, packets: Iterable[QuicSentPacket], now: float) -> None:
if self.callback:
self.callback(CongestionEvent.PACKET_LOST, self)
if type(self) == QuicCongestionControl:
for packet in packets:
# don't call this if it is a superclass that runs
self.data_in_flight -= packet.sent_bytes

def on_rtt_measurement(self, latest_rtt: float, now: float) -> None:
if self.callback:
self.callback(CongestionEvent.RTT_MEASURED, self)

def get_congestion_window(self) -> int:
# return the cwnd in number of bytes
return self.cwnd

def _set_congestion_window(self, value):
self.cwnd = value

def get_ssthresh(self) -> Optional[int]:
pass

def get_bytes_in_flight(self) -> int:
return self.data_in_flight

def log_callback(self) -> Dict[str, Any]:
# a callback called when a recovery happens
# The data object will be saved in the log file, so feel free to add
# any attribute you want to track
data: Dict[str, Any] = {
"bytes_in_flight": self.get_bytes_in_flight(),
"cwnd": self.get_congestion_window(),
}
if self.get_ssthresh() is not None:
data["ssthresh"] = self.get_ssthresh()

return data
Loading

0 comments on commit 4cf275a

Please sign in to comment.