diff --git a/src/core/net/tcp.cc b/src/core/net/tcp.cc
new file mode 100644
index 0000000..4374879
--- /dev/null
+++ b/src/core/net/tcp.cc
@@ -0,0 +1,44 @@
+/**
+ * @file tcp.cc
+ * @brief TCP header implementation.
+ */
+
+#include "tcp.h"
+
+namespace juggler {
+namespace net {
+
+std::string Tcp::ToString() const {
+  // Ports, window and urgent pointer are uint16_t; after default argument
+  // promotion they are passed as int, so %hu is the correct specifier.
+  // (%zu expects size_t and mismatching it is undefined behavior.)
+  return juggler::utils::Format(
+      "[TCP: src_port %hu, dst_port %hu, seq %u, ack %u, "
+      "offset %u, flags %s, win %hu, cksum 0x%04x, urg %hu]",
+      src_port.port.value(), dst_port.port.value(), seq_num.value(),
+      ack_num.value(), GetDataOffset(), FlagsToString().c_str(),
+      window.value(), cksum.value(), urgent_ptr.value());
+}
+
+std::string Tcp::FlagsToString() const {
+  std::string result;
+  if (HasFlag(kFin)) result += "FIN|";
+  if (HasFlag(kSyn)) result += "SYN|";
+  if (HasFlag(kRst)) result += "RST|";
+  if (HasFlag(kPsh)) result += "PSH|";
+  if (HasFlag(kAck)) result += "ACK|";
+  if (HasFlag(kUrg)) result += "URG|";
+  if (HasFlag(kEce)) result += "ECE|";
+  if (HasFlag(kCwr)) result += "CWR|";
+
+  // Remove trailing '|' if present
+  if (!result.empty() && result.back() == '|') {
+    result.pop_back();
+  }
+
+  if (result.empty()) {
+    result = "NONE";
+  }
+
+  return result;
+}
+
+}  // namespace net
+}  // namespace juggler
diff --git a/src/include/flow_key.h b/src/include/flow_key.h
index 96aefe5..694378d 100644
--- a/src/include/flow_key.h
+++ b/src/include/flow_key.h
@@ -1,14 +1,26 @@
 /**
  * @file flow_key.h
+ * @brief Flow key definitions for UDP and TCP flows.
  */
 
 #ifndef SRC_INCLUDE_FLOW_KEY_H_
 #define SRC_INCLUDE_FLOW_KEY_H_
 
 #include "ipv4.h"
+#include "tcp.h"
 #include "udp.h"
 
 namespace juggler {
 namespace net {
+
+/**
+ * @enum Protocol
+ * @brief Transport protocol identifier for flow keys.
+ */
+enum class Protocol : uint8_t {
+  kUdp = Ipv4::Proto::kUdp,  // UDP protocol (17)
+  kTcp = Ipv4::Proto::kTcp,  // TCP protocol (6)
+};
+
 namespace flow {
 
 struct Listener {
@@ -101,6 +113,110 @@ struct Key {
 static_assert(sizeof(Key) == 12, "Flow key size is not 12 bytes.");
 }  // namespace flow
+
+/**
+ * @namespace tcp_flow
+ * @brief TCP-specific flow key definitions.
+ */
+namespace tcp_flow {
+
+/**
+ * @struct Listener
+ * @brief TCP listener endpoint (IP address and port).
+ */
+struct Listener {
+  using Ipv4 = juggler::net::Ipv4;
+  using Tcp = juggler::net::Tcp;
+  Listener(const Listener& other) = default;
+
+  /**
+   * @brief Construct a new Listener object.
+   *
+   * @param local_addr Local IP address (in network byte order).
+   * @param local_port Local TCP port (in network byte order).
+   */
+  Listener(const Ipv4::Address& local_addr, const Tcp::Port& local_port)
+      : addr(local_addr), port(local_port) {}
+
+  /**
+   * @brief Construct a new Listener object.
+   *
+   * @param local_addr Local IP address (in host byte order).
+   * @param local_port Local TCP port (in host byte order).
+   */
+  Listener(const uint32_t local_addr, const uint16_t local_port)
+      : addr(local_addr), port(local_port) {}
+
+  bool operator==(const Listener& other) const {
+    return addr == other.addr && port == other.port;
+  }
+
+  // Members are const: a Listener is an immutable lookup key.
+  const Ipv4::Address addr;
+  const Tcp::Port port;
+};
+static_assert(sizeof(Listener) == 6, "TCP Listener size is not 6 bytes.");
+
+/**
+ * @struct Key
+ * @brief TCP flow key: corresponds to the 5-tuple (TCP is the protocol).
+ */
+struct Key {
+  using Ipv4 = juggler::net::Ipv4;
+  using Tcp = juggler::net::Tcp;
+  Key(const Key& other) = default;
+
+  /**
+   * @brief Construct a new Key object.
+   *
+   * @param local_addr Local IP address (in network byte order).
+   * @param local_port Local TCP port (in network byte order).
+   * @param remote_addr Remote IP address (in network byte order).
+   * @param remote_port Remote TCP port (in network byte order).
+   */
+  Key(const Ipv4::Address& local_addr, const Tcp::Port& local_port,
+      const Ipv4::Address& remote_addr, const Tcp::Port& remote_port)
+      : local_addr(local_addr),
+        local_port(local_port),
+        remote_addr(remote_addr),
+        remote_port(remote_port) {}
+
+  /**
+   * @brief Construct a new Key object.
+   *
+   * @param local_addr Local IP address (in host byte order).
+   * @param local_port Local TCP port (in host byte order).
+   * @param remote_addr Remote IP address (in host byte order).
+   * @param remote_port Remote TCP port (in host byte order).
+   */
+  Key(const uint32_t local_addr, const uint16_t local_port,
+      const uint32_t remote_addr, const uint16_t remote_port)
+      : local_addr(local_addr),
+        local_port(local_port),
+        remote_addr(remote_addr),
+        remote_port(remote_port) {}
+
+  bool operator==(const Key& other) const {
+    return local_addr == other.local_addr && local_port == other.local_port &&
+           remote_addr == other.remote_addr && remote_port == other.remote_port;
+  }
+
+  // Remote endpoint is printed first, matching "peer <-> us" log convention.
+  std::string ToString() const {
+    return utils::Format("[TCP %s:%hu <-> %s:%hu]",
+                         remote_addr.ToString().c_str(),
+                         remote_port.port.value(),
+                         local_addr.ToString().c_str(),
+                         local_port.port.value());
+  }
+
+  const Ipv4::Address local_addr;
+  const Tcp::Port local_port;
+  const Ipv4::Address remote_addr;
+  const Tcp::Port remote_port;
+};
+static_assert(sizeof(Key) == 12, "TCP flow key size is not 12 bytes.");
+
+}  // namespace tcp_flow
+
 }  // namespace net
 }  // namespace juggler
@@ -127,6 +243,23 @@ struct hash {
 };
 
+// TCP flow hash specializations.
+// NOTE(review): template arguments below were reconstructed after bracket
+// stripping in transit; verify the utils::hash signature against utils.h.
+template <>
+struct hash<juggler::net::tcp_flow::Listener> {
+  size_t operator()(const juggler::net::tcp_flow::Listener& listener) const {
+    return juggler::utils::hash<size_t>(
+        reinterpret_cast<const uint8_t*>(&listener), sizeof(listener));
+  }
+};
+
+template <>
+struct hash<juggler::net::tcp_flow::Key> {
+  size_t operator()(const juggler::net::tcp_flow::Key& key) const {
+    return juggler::utils::hash<size_t>(
+        reinterpret_cast<const uint8_t*>(&key), sizeof(key));
+  }
+};
+
 }  // namespace std
 
 #endif  // SRC_INCLUDE_FLOW_KEY_H_
diff --git 
a/src/include/tcp.h b/src/include/tcp.h
new file mode 100644
index 0000000..3d23133
--- /dev/null
+++ b/src/include/tcp.h
@@ -0,0 +1,180 @@
+/**
+ * @file tcp.h
+ * @brief TCP (Transmission Control Protocol) header definition.
+ */
+
+#ifndef SRC_INCLUDE_TCP_H_
+#define SRC_INCLUDE_TCP_H_
+
+// NOTE(review): include targets were lost in transit and reconstructed from
+// usage (uint8_t/std::string/juggler::utils::hash, be16_t/be32_t) — verify.
+#include <cstdint>
+#include <string>
+
+#include "utils.h"
+
+namespace juggler {
+namespace net {
+
+/**
+ * @struct Tcp
+ * @brief TCP header structure (RFC 793).
+ *
+ * The TCP header is 20 bytes without options:
+ *  0                   1                   2                   3
+ *  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ * |          Source Port          |       Destination Port        |
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ * |                        Sequence Number                        |
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ * |                    Acknowledgment Number                      |
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ * |  Data |           |U|A|P|R|S|F|                               |
+ * | Offset| Reserved  |R|C|S|S|Y|I|            Window             |
+ * |       |           |G|K|H|T|N|N|                               |
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ * |           Checksum            |         Urgent Pointer        |
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+ */
+struct __attribute__((packed)) Tcp {
+  static constexpr uint8_t kMinHeaderLen = 20;
+  static constexpr uint8_t kMaxHeaderLen = 60;
+
+  /**
+   * @struct Port
+   * @brief TCP port wrapper (identical to UDP port structure for consistency).
+   */
+  struct __attribute__((packed)) Port {
+    static const uint8_t kSize = 2;
+    Port() = default;
+    Port(uint16_t tcp_port) { port = be16_t(tcp_port); }
+    bool operator==(const Port &rhs) const { return port == rhs.port; }
+    bool operator==(be16_t rhs) const { return rhs == port; }
+    bool operator!=(const Port &rhs) const { return port != rhs.port; }
+    bool operator!=(be16_t rhs) const { return rhs != port; }
+
+    be16_t port;  // Port number in network byte order.
+  };
+
+  /**
+   * @enum Flags
+   * @brief TCP control flags.
+   */
+  enum Flags : uint8_t {
+    kFin = 0x01,  // Finish - no more data from sender
+    kSyn = 0x02,  // Synchronize - initiate connection
+    kRst = 0x04,  // Reset - abort connection
+    kPsh = 0x08,  // Push - push data to application
+    kAck = 0x10,  // Acknowledgment field is valid
+    kUrg = 0x20,  // Urgent pointer field is valid
+    kEce = 0x40,  // ECN-Echo (RFC 3168)
+    kCwr = 0x80,  // Congestion Window Reduced (RFC 3168)
+  };
+
+  /**
+   * @brief Get the data offset (header length) in 32-bit words.
+   * @return Header length in 32-bit words (5-15).
+   */
+  uint8_t GetDataOffset() const { return (data_offset_reserved >> 4) & 0x0F; }
+
+  /**
+   * @brief Get the header length in bytes.
+   * @return Header length in bytes (20-60).
+   */
+  uint8_t GetHeaderLength() const { return GetDataOffset() * 4; }
+
+  /**
+   * @brief Set the data offset (header length) in 32-bit words.
+   * @param offset Header length in 32-bit words (5-15).
+   */
+  void SetDataOffset(uint8_t offset) {
+    // Preserve the low (reserved) nibble while replacing the high nibble.
+    data_offset_reserved = (offset << 4) | (data_offset_reserved & 0x0F);
+  }
+
+  /**
+   * @brief Get TCP flags.
+   * @return TCP flags byte.
+   */
+  uint8_t GetFlags() const { return flags; }
+
+  /**
+   * @brief Set TCP flags.
+   * @param f Flags to set.
+   */
+  void SetFlags(uint8_t f) { flags = f; }
+
+  /**
+   * @brief Check if a specific flag is set.
+   * @param flag Flag to check.
+   * @return true if flag is set, false otherwise.
+   */
+  bool HasFlag(Flags flag) const { return (flags & flag) != 0; }
+
+  /**
+   * @brief Set a specific flag.
+   * @param flag Flag to set.
+   */
+  void SetFlag(Flags flag) { flags |= flag; }
+
+  /**
+   * @brief Clear a specific flag.
+   * @param flag Flag to clear.
+   */
+  void ClearFlag(Flags flag) { flags &= ~flag; }
+
+  /**
+   * @brief Convert TCP header to human-readable string.
+   * @return String representation of the TCP header.
+   */
+  std::string ToString() const;
+
+  /**
+   * @brief Get string representation of TCP flags.
+   * @return String with flag names.
+   */
+  std::string FlagsToString() const;
+
+  // TCP Header Fields
+  Port src_port;                 // Source port
+  Port dst_port;                 // Destination port
+  be32_t seq_num;                // Sequence number
+  be32_t ack_num;                // Acknowledgment number
+  uint8_t data_offset_reserved;  // Data offset (4 bits) + Reserved (4 bits)
+  uint8_t flags;                 // Control flags
+  be16_t window;                 // Window size
+  be16_t cksum;                  // Checksum
+  be16_t urgent_ptr;             // Urgent pointer
+};
+
+static_assert(sizeof(Tcp) == 20, "TCP header size must be 20 bytes");
+
+/**
+ * @brief Bitwise OR operator for TCP flags.
+ */
+inline Tcp::Flags operator|(Tcp::Flags lhs, Tcp::Flags rhs) {
+  return static_cast<Tcp::Flags>(static_cast<uint8_t>(lhs) |
+                                 static_cast<uint8_t>(rhs));
+}
+
+/**
+ * @brief Bitwise AND operator for TCP flags.
+ */
+inline Tcp::Flags operator&(Tcp::Flags lhs, Tcp::Flags rhs) {
+  return static_cast<Tcp::Flags>(static_cast<uint8_t>(lhs) &
+                                 static_cast<uint8_t>(rhs));
+}
+
+}  // namespace net
+}  // namespace juggler
+
+namespace std {
+// NOTE(review): template arguments reconstructed after bracket stripping;
+// verify against the UDP Port hash specialization for consistency.
+template <>
+struct hash<juggler::net::Tcp::Port> {
+  std::size_t operator()(const juggler::net::Tcp::Port &port) const {
+    return juggler::utils::hash<std::size_t>(
+        reinterpret_cast<const uint8_t*>(&port.port),
+        sizeof(port.port.raw_value()));
+  }
+};
+}  // namespace std
+
+#endif  // SRC_INCLUDE_TCP_H_
diff --git a/src/include/tcp_cc.h b/src/include/tcp_cc.h
new file mode 100644
index 0000000..c5caa09
--- /dev/null
+++ b/src/include/tcp_cc.h
@@ -0,0 +1,539 @@
+/**
+ * @file tcp_cc.h
+ * @brief TCP Congestion Control implementation.
+ *
+ * This file implements standard TCP congestion control algorithms:
+ * - Slow Start
+ * - Congestion Avoidance
+ * - Fast Retransmit
+ * - Fast Recovery (RFC 5681)
+ * - NewReno-style recovery
+ */
+
+#ifndef SRC_INCLUDE_TCP_CC_H_
+#define SRC_INCLUDE_TCP_CC_H_
+
+// NOTE(review): include targets reconstructed from usage
+// (std::max/min/clamp, fixed-width ints, std::size_t, std::string) — verify.
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <string>
+
+#include "utils.h"
+
+namespace juggler {
+namespace net {
+namespace tcp {
+
+/**
+ * @brief Sequence number comparison functions for wrap-around handling.
+ * Uses signed arithmetic to handle 32-bit sequence number wrap-around.
+ */
+// The cast to int32_t is essential: (a - b) interpreted as signed makes the
+// comparison correct across the 2^32 wrap (RFC 793 sequence space).
+constexpr bool seqno_lt(uint32_t a, uint32_t b) {
+  return static_cast<int32_t>(a - b) < 0;
+}
+constexpr bool seqno_le(uint32_t a, uint32_t b) {
+  return static_cast<int32_t>(a - b) <= 0;
+}
+constexpr bool seqno_eq(uint32_t a, uint32_t b) {
+  return static_cast<int32_t>(a - b) == 0;
+}
+constexpr bool seqno_ge(uint32_t a, uint32_t b) {
+  return static_cast<int32_t>(a - b) >= 0;
+}
+constexpr bool seqno_gt(uint32_t a, uint32_t b) {
+  return static_cast<int32_t>(a - b) > 0;
+}
+
+/**
+ * @enum TcpState
+ * @brief TCP connection states (RFC 793).
+ */
+enum class TcpState : uint8_t {
+  kClosed,
+  kListen,
+  kSynSent,
+  kSynReceived,
+  kEstablished,
+  kFinWait1,
+  kFinWait2,
+  kCloseWait,
+  kClosing,
+  kLastAck,
+  kTimeWait,
+};
+
+/**
+ * @brief Convert TCP state to string.
+ */
+inline constexpr const char* TcpStateToString(TcpState state) {
+  switch (state) {
+    case TcpState::kClosed:
+      return "CLOSED";
+    case TcpState::kListen:
+      return "LISTEN";
+    case TcpState::kSynSent:
+      return "SYN_SENT";
+    case TcpState::kSynReceived:
+      return "SYN_RECEIVED";
+    case TcpState::kEstablished:
+      return "ESTABLISHED";
+    case TcpState::kFinWait1:
+      return "FIN_WAIT_1";
+    case TcpState::kFinWait2:
+      return "FIN_WAIT_2";
+    case TcpState::kCloseWait:
+      return "CLOSE_WAIT";
+    case TcpState::kClosing:
+      return "CLOSING";
+    case TcpState::kLastAck:
+      return "LAST_ACK";
+    case TcpState::kTimeWait:
+      return "TIME_WAIT";
+    default:
+      return "UNKNOWN";
+  }
+}
+
+/**
+ * @enum CongestionState
+ * @brief TCP congestion control states.
+ */
+enum class CongestionState : uint8_t {
+  kSlowStart,            // Exponential cwnd growth
+  kCongestionAvoidance,  // Linear cwnd growth (AIMD)
+  kFastRecovery,         // After fast retransmit (NewReno)
+};
+
+/**
+ * @brief Convert congestion state to string.
+ */
+inline constexpr const char* CongestionStateToString(CongestionState state) {
+  switch (state) {
+    case CongestionState::kSlowStart:
+      return "SLOW_START";
+    case CongestionState::kCongestionAvoidance:
+      return "CONGESTION_AVOIDANCE";
+    case CongestionState::kFastRecovery:
+      return "FAST_RECOVERY";
+    default:
+      return "UNKNOWN";
+  }
+}
+
+/**
+ * @struct RttEstimator
+ * @brief RTT estimation using Jacobson/Karels algorithm (RFC 6298).
+ */
+struct RttEstimator {
+  static constexpr int64_t kInitialRttUs = 1000000;    // 1 second initial RTT
+  static constexpr int64_t kMinRtoUs = 200000;         // 200ms minimum RTO
+  static constexpr int64_t kMaxRtoUs = 60000000;       // 60 seconds maximum RTO
+  static constexpr int kRtoClockGranularityUs = 1000;  // 1ms granularity
+
+  // Alpha = 1/8, Beta = 1/4 (RFC 6298 recommended values)
+  static constexpr int kAlphaShift = 3;  // 1/8
+  static constexpr int kBetaShift = 2;   // 1/4
+
+  int64_t srtt_us{0};             // Smoothed RTT (microseconds)
+  int64_t rttvar_us{0};           // RTT variance (microseconds)
+  int64_t rto_us{kInitialRttUs};  // Retransmission timeout (microseconds)
+  bool has_measurement{false};
+
+  /**
+   * @brief Update RTT estimate with a new measurement.
+   * @param rtt_us Measured RTT in microseconds.
+   */
+  void UpdateRtt(int64_t rtt_us) {
+    if (!has_measurement) {
+      // First RTT measurement: SRTT = R, RTTVAR = R/2 (RFC 6298 2.2).
+      srtt_us = rtt_us;
+      rttvar_us = rtt_us / 2;
+      has_measurement = true;
+    } else {
+      // Jacobson/Karels algorithm. delta is computed against the old SRTT
+      // before SRTT is updated, as the algorithm requires.
+      int64_t delta = rtt_us - srtt_us;
+      srtt_us += delta >> kAlphaShift;
+      int64_t abs_delta = delta >= 0 ? delta : -delta;
+      rttvar_us += (abs_delta - rttvar_us) >> kBetaShift;
+    }
+
+    // RTO = SRTT + max(G, 4 * RTTVAR)
+    int64_t k_rttvar = rttvar_us << 2;  // 4 * RTTVAR
+    int64_t timeout =
+        srtt_us + std::max(static_cast<int64_t>(kRtoClockGranularityUs),
+                           k_rttvar);
+    rto_us = std::clamp(timeout, kMinRtoUs, kMaxRtoUs);
+  }
+
+  /**
+   * @brief Apply exponential backoff to RTO (on timeout).
+   */
+  void BackoffRto() {
+    rto_us = std::min(rto_us * 2, kMaxRtoUs);
+  }
+
+  /**
+   * @brief Reset estimator to initial state.
+   */
+  void Reset() {
+    srtt_us = 0;
+    rttvar_us = 0;
+    rto_us = kInitialRttUs;
+    has_measurement = false;
+  }
+
+  std::string ToString() const {
+    // NOTE(review): %ld assumes int64_t == long (LP64); use PRId64 if this
+    // must build on LLP64 targets — confirm utils::Format semantics.
+    return utils::Format("srtt=%ldus, rttvar=%ldus, rto=%ldus",
+                         srtt_us, rttvar_us, rto_us);
+  }
+};
+
+/**
+ * @struct TcpControlBlock
+ * @brief TCP Protocol Control Block (PCB) for congestion control.
+ *
+ * Implements standard TCP congestion control with:
+ * - Slow start with initial window
+ * - Congestion avoidance (AIMD)
+ * - Fast retransmit and fast recovery (NewReno)
+ * - SACK support
+ */
+struct TcpControlBlock {
+  // Constants
+  static constexpr uint32_t kInitialCwnd = 10;         // Initial cwnd (RFC 6928)
+  static constexpr uint32_t kMinCwnd = 2;              // Minimum cwnd
+  static constexpr uint32_t kInitialSsthresh = 65535;  // Initial ssthresh
+  static constexpr uint32_t kDupAckThreshold = 3;  // Dup ACKs for fast retransmit
+  static constexpr uint16_t kDefaultMss = 1460;    // Default MSS
+  static constexpr std::size_t kSackBlocksMax = 4;  // Max SACK blocks
+  static constexpr int kRtoDisabled = -1;
+  static constexpr int kMaxRexmits = 15;  // Max retransmissions
+
+  // Send sequence variables (RFC 793)
+  uint32_t snd_una{0};      // Oldest unacknowledged sequence number
+  uint32_t snd_nxt{0};      // Next sequence number to send
+  uint32_t snd_wnd{65535};  // Send window (advertised by receiver)
+  uint32_t iss{0};          // Initial send sequence number
+
+  // Receive sequence variables
+  uint32_t rcv_nxt{0};      // Next expected sequence number
+  uint32_t rcv_wnd{65535};  // Receive window (advertised to sender)
+  uint32_t irs{0};          // Initial receive sequence number
+
+  // Congestion control variables
+  uint32_t cwnd{kInitialCwnd * kDefaultMss};  // Congestion window (bytes)
+  uint32_t ssthresh{kInitialSsthresh};        // Slow start threshold
+  uint16_t mss{kDefaultMss};                  // Maximum segment size
+  CongestionState
+      cc_state{CongestionState::kSlowStart};
+
+  // Fast retransmit/recovery state
+  uint32_t dup_ack_count{0};  // Duplicate ACK counter
+  uint32_t recover{0};        // Recovery point (NewReno)
+  uint32_t high_rxt{0};       // Highest retransmitted sequence
+
+  // SACK support
+  struct SackBlock {
+    uint32_t left{0};   // Left edge of SACK block
+    uint32_t right{0};  // Right edge of SACK block
+    bool valid{false};
+  };
+  SackBlock sack_blocks[kSackBlocksMax];
+  uint8_t sack_block_count{0};
+
+  // RTT estimation
+  RttEstimator rtt_est;
+
+  // RTO timer state
+  int rto_timer{kRtoDisabled};
+  int rto_rexmits{0};
+
+  // Statistics
+  uint64_t bytes_sent{0};
+  uint64_t bytes_acked{0};
+  uint64_t packets_retransmitted{0};
+
+  /**
+   * @brief Initialize the control block with an initial sequence number.
+   * @param initial_seq Initial sequence number.
+   */
+  void Initialize(uint32_t initial_seq) {
+    iss = initial_seq;
+    snd_una = initial_seq;
+    snd_nxt = initial_seq;
+    cwnd = kInitialCwnd * mss;
+    ssthresh = kInitialSsthresh;
+    cc_state = CongestionState::kSlowStart;
+    dup_ack_count = 0;
+    rto_timer = kRtoDisabled;
+    rto_rexmits = 0;
+  }
+
+  /**
+   * @brief Get the current send sequence number and advance.
+   * @return Current snd_nxt before advancing.
+   */
+  uint32_t GetAndAdvanceSndNxt(uint32_t len = 1) {
+    uint32_t seq = snd_nxt;
+    snd_nxt += len;
+    return seq;
+  }
+
+  /**
+   * @brief Calculate effective send window (min of cwnd and receiver window).
+   * @return Effective window in bytes.
+   */
+  uint32_t EffectiveWindow() const {
+    uint32_t wnd = std::min(cwnd, snd_wnd);
+    uint32_t flight_size = snd_nxt - snd_una;
+    // Guard against flight > wnd (unsigned subtraction would wrap).
+    return flight_size >= wnd ? 0 : wnd - flight_size;
+  }
+
+  /**
+   * @brief Get the number of bytes in flight.
+   * @return Bytes in flight.
+   */
+  uint32_t FlightSize() const { return snd_nxt - snd_una; }
+
+  /**
+   * @brief Process an ACK and update congestion control state.
+   * @param ack_num Acknowledgment number from received ACK.
+   * @param is_dup True if this is a duplicate ACK.
+   * @return Number of newly acknowledged bytes.
+   */
+  uint32_t ProcessAck(uint32_t ack_num, bool is_dup) {
+    uint32_t newly_acked = 0;
+
+    if (is_dup) {
+      return ProcessDuplicateAck(ack_num);
+    }
+
+    // Check for valid ACK
+    if (seqno_gt(ack_num, snd_nxt)) {
+      // ACK for data not yet sent - invalid
+      return 0;
+    }
+
+    if (seqno_le(ack_num, snd_una)) {
+      // Old/duplicate ACK
+      return ProcessDuplicateAck(ack_num);
+    }
+
+    // New ACK - calculate bytes acknowledged
+    newly_acked = ack_num - snd_una;
+    snd_una = ack_num;
+    bytes_acked += newly_acked;
+
+    // Reset duplicate ACK counter
+    dup_ack_count = 0;
+
+    // Update congestion control based on state
+    switch (cc_state) {
+      case CongestionState::kSlowStart:
+        OnAckSlowStart(newly_acked);
+        break;
+      case CongestionState::kCongestionAvoidance:
+        OnAckCongestionAvoidance(newly_acked);
+        break;
+      case CongestionState::kFastRecovery:
+        OnAckFastRecovery(ack_num, newly_acked);
+        break;
+    }
+
+    // Reset RTO timer if there's still outstanding data
+    if (snd_una == snd_nxt) {
+      RtoDisable();
+    } else {
+      RtoReset();
+    }
+    rto_rexmits = 0;
+
+    return newly_acked;
+  }
+
+  /**
+   * @brief Handle a timeout event.
+   */
+  void OnTimeout() {
+    // Set ssthresh to half of flight size (RFC 5681). Explicit template
+    // argument: mss promotes to int, FlightSize() is uint32_t, and mixed
+    // argument types would not otherwise deduce.
+    ssthresh = std::max<uint32_t>(FlightSize() / 2, 2 * mss);
+
+    // Reset cwnd to 1 MSS (or IW for loss-based)
+    cwnd = mss;
+
+    // Enter slow start
+    cc_state = CongestionState::kSlowStart;
+
+    // Clear duplicate ACK count
+    dup_ack_count = 0;
+
+    // Backoff RTO
+    rtt_est.BackoffRto();
+    rto_rexmits++;
+    packets_retransmitted++;
+  }
+
+  /**
+   * @brief Check if fast retransmit should be triggered.
+   * @return True if fast retransmit condition is met.
+   */
+  bool ShouldFastRetransmit() const {
+    return dup_ack_count >= kDupAckThreshold &&
+           cc_state != CongestionState::kFastRecovery;
+  }
+
+  /**
+   * @brief Check if max retransmissions have been reached.
+   */
+  bool MaxRexmitsReached() const { return rto_rexmits >= kMaxRexmits; }
+
+  // RTO timer management
+  bool RtoDisabled() const { return rto_timer == kRtoDisabled; }
+  bool RtoExpired() const {
+    // NOTE(review): assumes one RtoAdvance() tick == 100000 us (100 ms);
+    // confirm against the event loop's actual tick period.
+    return rto_timer >= 0 &&
+           static_cast<int64_t>(rto_timer) * 100000 >= rtt_est.rto_us;
+  }
+  void RtoEnable() { rto_timer = 0; }
+  void RtoDisable() { rto_timer = kRtoDisabled; }
+  void RtoReset() { rto_timer = 0; }
+  void RtoAdvance() {
+    if (rto_timer >= 0) rto_timer++;
+  }
+
+  /**
+   * @brief Update the send window from the peer's advertised window.
+   * @param window Window advertised by the remote receiver.
+   */
+  void UpdateSndWnd(uint16_t window) { snd_wnd = window; }
+
+  /**
+   * @brief Update MSS.
+   * @param new_mss New maximum segment size.
+   */
+  void UpdateMss(uint16_t new_mss) {
+    mss = new_mss;
+    // Adjust cwnd to be a multiple of MSS
+    if (cwnd < mss) cwnd = mss;
+  }
+
+  /**
+   * @brief Add a SACK block.
+   * @param left Left edge of the block.
+   * @param right Right edge of the block.
+   */
+  void AddSackBlock(uint32_t left, uint32_t right) {
+    if (sack_block_count < kSackBlocksMax) {
+      sack_blocks[sack_block_count].left = left;
+      sack_blocks[sack_block_count].right = right;
+      sack_blocks[sack_block_count].valid = true;
+      sack_block_count++;
+    }
+  }
+
+  /**
+   * @brief Clear all SACK blocks.
+   */
+  void ClearSackBlocks() {
+    for (size_t i = 0; i < kSackBlocksMax; i++) {
+      sack_blocks[i].valid = false;
+    }
+    sack_block_count = 0;
+  }
+
+  std::string ToString() const {
+    return utils::Format(
+        "[TCP CC] snd_una=%u, snd_nxt=%u, cwnd=%u, ssthresh=%u, "
+        "state=%s, dup_acks=%u, flight=%u, eff_wnd=%u, %s",
+        snd_una, snd_nxt, cwnd, ssthresh,
+        CongestionStateToString(cc_state), dup_ack_count,
+        FlightSize(), EffectiveWindow(),
+        rtt_est.ToString().c_str());
+  }
+
+ private:
+  /**
+   * @brief Handle ACK in slow start phase.
+   */
+  void OnAckSlowStart(uint32_t newly_acked) {
+    // Exponential growth: increase cwnd by number of bytes ACKed.
+    // NOTE(review): RFC 5681 caps the per-ACK increase at SMSS; this
+    // byte-counting variant is more aggressive — confirm it is intentional.
+    cwnd += newly_acked;
+
+    // Check if we should transition to congestion avoidance
+    if (cwnd >= ssthresh) {
+      cc_state = CongestionState::kCongestionAvoidance;
+    }
+  }
+
+  /**
+   * @brief Handle ACK in congestion avoidance phase.
+   */
+  void OnAckCongestionAvoidance(uint32_t newly_acked) {
+    // Linear growth: increase cwnd by MSS^2/cwnd for each ACK.
+    // This approximates: cwnd += MSS per RTT.
+    // 64-bit intermediate: mss * newly_acked overflows uint32_t for large
+    // cumulative ACKs (e.g. 1460 * ~3MB of newly acked data).
+    uint32_t increment = static_cast<uint32_t>(
+        static_cast<uint64_t>(mss) * newly_acked / cwnd);
+    if (increment == 0) increment = 1;
+    cwnd += increment;
+  }
+
+  /**
+   * @brief Handle ACK in fast recovery phase (NewReno).
+   */
+  void OnAckFastRecovery(uint32_t ack_num, uint32_t newly_acked) {
+    if (seqno_ge(ack_num, recover)) {
+      // Full ACK - exit fast recovery
+      cwnd = std::min(ssthresh, FlightSize() + mss);
+      cc_state = CongestionState::kCongestionAvoidance;
+    } else {
+      // Partial ACK - stay in fast recovery.
+      // Deflate cwnd by the amount of new data ACKed, clamped at one MSS:
+      // an unguarded unsigned subtraction could wrap to a huge cwnd when
+      // newly_acked exceeds the (inflated) cwnd.
+      cwnd = cwnd > newly_acked ? cwnd - newly_acked : mss;
+      // Add back one segment (to compensate for partial ACK)
+      cwnd += mss;
+      // Retransmit next unacked segment
+      packets_retransmitted++;
+    }
+  }
+
+  /**
+   * @brief Handle duplicate ACK.
+   * @return Bytes acknowledged (always 0 for dup ACKs).
+   */
+  uint32_t ProcessDuplicateAck(uint32_t ack_num) {
+    dup_ack_count++;
+
+    if (cc_state == CongestionState::kFastRecovery) {
+      // In fast recovery, inflate cwnd
+      cwnd += mss;
+      return 0;
+    }
+
+    if (dup_ack_count == kDupAckThreshold) {
+      // Enter fast retransmit/recovery
+      EnterFastRecovery();
+    }
+
+    return 0;
+  }
+
+  /**
+   * @brief Enter fast recovery mode.
+   */
+  void EnterFastRecovery() {
+    // Set ssthresh to half of flight size. Explicit template argument:
+    // 2 * mss promotes to int and would otherwise fail deduction.
+    ssthresh = std::max<uint32_t>(FlightSize() / 2, 2 * mss);
+
+    // Set recovery point
+    recover = snd_nxt;
+
+    // Set cwnd = ssthresh + 3*MSS (RFC 5681)
+    cwnd = ssthresh + kDupAckThreshold * mss;
+
+    // Enter fast recovery state
+    cc_state = CongestionState::kFastRecovery;
+
+    packets_retransmitted++;
+  }
+};
+
+}  // namespace tcp
+}  // namespace net
+}  // namespace juggler
+
+#endif  // SRC_INCLUDE_TCP_CC_H_
diff --git a/src/include/tcp_flow.h b/src/include/tcp_flow.h
new file mode 100644
index 0000000..7b3fd48
--- /dev/null
+++ b/src/include/tcp_flow.h
@@ -0,0 +1,821 @@
+/**
+ * @file tcp_flow.h
+ * @brief Class to abstract the components and functionality of a single TCP
+ * flow.
+ */
+
+#ifndef SRC_INCLUDE_TCP_FLOW_H_
+#define SRC_INCLUDE_TCP_FLOW_H_
+
+// NOTE(review): the targets of the following includes were lost in transit
+// and cannot be reconstructed from this chunk alone — restore from VCS.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace juggler {
+namespace net {
+namespace tcp_flow {
+
+/**
+ * @class TXTracking
+ * @brief Tracking for message buffers that are sent to the network.
+ * Manages the chain of outgoing message buffers and their acknowledgment
+ * status.
+ */ +class TXTracking { + public: + TXTracking() = delete; + explicit TXTracking(shm::Channel* channel) + : channel_(CHECK_NOTNULL(channel)), + oldest_unacked_msgbuf_(nullptr), + oldest_unsent_msgbuf_(nullptr), + last_msgbuf_(nullptr), + num_unsent_msgbufs_(0), + num_tracked_msgbufs_(0) {} + + uint32_t NumUnsentMsgbufs() const { return num_unsent_msgbufs_; } + shm::MsgBuf* GetOldestUnackedMsgBuf() const { return oldest_unacked_msgbuf_; } + + void ReceiveAcks(uint32_t num_acked_pkts) { + shm::MsgBufBatch to_free; + while (num_acked_pkts) { + auto msgbuf = oldest_unacked_msgbuf_; + DCHECK(msgbuf != nullptr); + if (msgbuf != last_msgbuf_) { + DCHECK_NE(oldest_unacked_msgbuf_, oldest_unsent_msgbuf_) + << "Releasing an unsent msgbuf!"; + oldest_unacked_msgbuf_ = channel_->GetMsgBuf(msgbuf->next()); + } else { + oldest_unacked_msgbuf_ = nullptr; + last_msgbuf_ = nullptr; + } + to_free.Append(msgbuf, msgbuf->index()); + if (to_free.IsFull()) { + num_tracked_msgbufs_ -= to_free.GetSize(); + CHECK(channel_->MsgBufBulkFree(&to_free)); + } + num_acked_pkts--; + } + + num_tracked_msgbufs_ -= to_free.GetSize(); + CHECK(channel_->MsgBufBulkFree(&to_free)); + } + + void Append(shm::MsgBuf* msgbuf) { + DCHECK(msgbuf->is_first()); + // Append the message at the end of the chain of buffers, if any. + if (last_msgbuf_ == nullptr) { + // This is the first pending message buffer in the flow. + DCHECK(oldest_unsent_msgbuf_ == nullptr); + last_msgbuf_ = channel_->GetMsgBuf(msgbuf->last()); + oldest_unsent_msgbuf_ = msgbuf; + oldest_unacked_msgbuf_ = msgbuf; + } else { + // This is not the first message buffer in the flow. + DCHECK(oldest_unacked_msgbuf_ != nullptr); + // Let's enqueue the new message buffer at the end of the chain. + last_msgbuf_->link(msgbuf); + DCHECK(!(last_msgbuf_->is_last() && last_msgbuf_->is_sg())); + // Update the last buffer pointer to point to the current buffer. 
+ last_msgbuf_ = channel_->GetMsgBuf(msgbuf->last()); + if (oldest_unsent_msgbuf_ == nullptr) oldest_unsent_msgbuf_ = msgbuf; + } + + const auto msg_length = msgbuf->msg_length(); + const auto effective_buffer_size = channel_->GetUsableBufSize(); + const auto msg_buffers_nr = + (msg_length + effective_buffer_size - 1) / effective_buffer_size; + num_unsent_msgbufs_ += msg_buffers_nr; + num_tracked_msgbufs_ += msg_buffers_nr; + } + + std::optional GetAndUpdateOldestUnsent() { + if (oldest_unsent_msgbuf_ == nullptr) { + DCHECK_EQ(NumUnsentMsgbufs(), 0); + return std::nullopt; + } + + auto msgbuf = oldest_unsent_msgbuf_; + if (oldest_unsent_msgbuf_ != last_msgbuf_) { + oldest_unsent_msgbuf_ = + channel_->GetMsgBuf(oldest_unsent_msgbuf_->next()); + } else { + oldest_unsent_msgbuf_ = nullptr; + } + + num_unsent_msgbufs_--; + return msgbuf; + } + + private: + uint32_t NumTrackedMsgbufs() const { return num_tracked_msgbufs_; } + const shm::MsgBuf* GetLastMsgBuf() const { return last_msgbuf_; } + const shm::MsgBuf* GetOldestUnsentMsgBuf() const { + return oldest_unsent_msgbuf_; + } + + shm::Channel* channel_; + shm::MsgBuf* oldest_unacked_msgbuf_; + shm::MsgBuf* oldest_unsent_msgbuf_; + shm::MsgBuf* last_msgbuf_; + uint32_t num_unsent_msgbufs_; + uint32_t num_tracked_msgbufs_; +}; + +/** + * @class RXTracking + * @brief Tracking for message buffers received from the network. + * Handles out-of-order reception and delivers complete messages to the + * application. 
+ */ +class RXTracking { + public: + struct reasm_queue_ent_t { + shm::MsgBuf* msgbuf; + uint32_t seqno; + + reasm_queue_ent_t(shm::MsgBuf* m, uint32_t s) : msgbuf(m), seqno(s) {} + }; + + static constexpr std::size_t kReassemblyMaxSeqnoDistance = 256; + + RXTracking(const RXTracking&) = delete; + RXTracking(uint32_t local_ip, uint16_t local_port, uint32_t remote_ip, + uint16_t remote_port, shm::Channel* channel) + : local_ip_(local_ip), + local_port_(local_port), + remote_ip_(remote_ip), + remote_port_(remote_port), + channel_(CHECK_NOTNULL(channel)), + cur_msg_train_head_(nullptr), + cur_msg_train_tail_(nullptr) {} + + /** + * @brief Consume a TCP data packet. + * @param pcb TCP control block. + * @param packet Pointer to the received packet. + * @param payload_offset Offset to the payload in the packet. + * @param payload_len Length of the payload. + * @return 0 on success, -1 on failure. + */ + int Consume(tcp::TcpControlBlock* pcb, const dpdk::Packet* packet, + size_t payload_offset, size_t payload_len) { + const auto* tcph = packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + const auto* payload = packet->head_data(payload_offset); + const auto seqno = tcph->seq_num.value(); + const auto expected_seqno = pcb->rcv_nxt; + + if (tcp::seqno_lt(seqno, expected_seqno)) { + VLOG(2) << "Received old packet: " << seqno << " < " << expected_seqno; + return 0; + } + + const size_t distance = seqno - expected_seqno; + if (distance >= kReassemblyMaxSeqnoDistance) { + LOG(ERROR) << "Packet too far ahead. 
seqno: " << seqno + << ", expected: " << expected_seqno; + return 0; + } + + // Check for duplicates in out-of-order case + auto it = reass_q_.begin(); + if (seqno != expected_seqno) { + it = std::find_if(reass_q_.begin(), reass_q_.end(), + [&seqno](const reasm_queue_ent_t& entry) { + return entry.seqno >= seqno; + }); + if (it != reass_q_.end() && it->seqno == seqno) { + return 0; // Duplicate packet + } + } + + // Allocate buffer in SHM channel + auto* msgbuf = channel_->MsgBufAlloc(); + if (msgbuf == nullptr) { + VLOG(1) << "Failed to allocate message buffer. Dropping packet."; + return -1; + } + + auto* msg_data = msgbuf->append(payload_len); + utils::Copy(CHECK_NOTNULL(msg_data), payload, payload_len); + msgbuf->set_src_ip(remote_ip_); + msgbuf->set_src_port(remote_port_); + msgbuf->set_dst_ip(local_ip_); + msgbuf->set_dst_port(local_port_); + + if (seqno == expected_seqno) { + reass_q_.emplace_front(msgbuf, seqno); + } else { + reass_q_.insert(it, reasm_queue_ent_t(msgbuf, seqno)); + } + + // Add SACK block for out-of-order packet + if (seqno != expected_seqno) { + pcb->AddSackBlock(seqno, seqno + payload_len); + } + + PushInOrderMsgbufsToShmTrain(pcb, payload_len); + return 0; + } + + private: + void PushInOrderMsgbufsToShmTrain(tcp::TcpControlBlock* pcb, + size_t payload_len) { + while (!reass_q_.empty() && reass_q_.front().seqno == pcb->rcv_nxt) { + auto& front = reass_q_.front(); + auto* msgbuf = front.msgbuf; + reass_q_.pop_front(); + + if (cur_msg_train_head_ == nullptr) { + DCHECK(msgbuf->is_first()); + cur_msg_train_head_ = msgbuf; + cur_msg_train_tail_ = msgbuf; + } else { + cur_msg_train_tail_->set_next(msgbuf); + cur_msg_train_tail_ = msgbuf; + } + + if (cur_msg_train_tail_->is_last()) { + // Complete message, deliver to application + DCHECK(!cur_msg_train_tail_->is_sg()); + auto* msgbuf_to_deliver = cur_msg_train_head_; + auto nr_delivered = channel_->EnqueueMessages(&msgbuf_to_deliver, 1); + if (nr_delivered != 1) { + LOG(FATAL) << "SHM channel 
full, failed to deliver message"; + } + + cur_msg_train_head_ = nullptr; + cur_msg_train_tail_ = nullptr; + } + + // Advance rcv_nxt by payload length + pcb->rcv_nxt += payload_len; + pcb->ClearSackBlocks(); + } + } + + const uint32_t local_ip_; + const uint16_t local_port_; + const uint32_t remote_ip_; + const uint16_t remote_port_; + shm::Channel* channel_; + std::deque reass_q_; + shm::MsgBuf* cur_msg_train_head_; + shm::MsgBuf* cur_msg_train_tail_; +}; + +/** + * @class TcpFlow + * @brief A TCP flow representing a connection between local and remote + * endpoints. + * + * Manages TCP connection state, congestion control, and data transfer + * using the TCP protocol instead of UDP. + */ +class TcpFlow { + public: + using Ethernet = net::Ethernet; + using Ipv4 = net::Ipv4; + using Tcp = net::Tcp; + using ApplicationCallback = + std::function; + + /** + * @brief Construct a new TCP flow. + * + * @param local_addr Local IP address. + * @param local_port Local TCP port. + * @param remote_addr Remote IP address. + * @param remote_port Remote TCP port. + * @param local_l2_addr Local L2 address. + * @param remote_l2_addr Remote L2 address. + * @param txring TX ring to send packets to. + * @param channel Shared memory channel this flow is associated with. 
+ */ + TcpFlow(const Ipv4::Address& local_addr, const Tcp::Port& local_port, + const Ipv4::Address& remote_addr, const Tcp::Port& remote_port, + const Ethernet::Address& local_l2_addr, + const Ethernet::Address& remote_l2_addr, dpdk::TxRing* txring, + ApplicationCallback callback, shm::Channel* channel) + : key_(local_addr, local_port, remote_addr, remote_port), + local_l2_addr_(local_l2_addr), + remote_l2_addr_(remote_l2_addr), + state_(tcp::TcpState::kClosed), + txring_(CHECK_NOTNULL(txring)), + callback_(std::move(callback)), + channel_(CHECK_NOTNULL(channel)), + pcb_(), + tx_tracking_(CHECK_NOTNULL(channel)), + rx_tracking_(local_addr.address.value(), local_port.port.value(), + remote_addr.address.value(), remote_port.port.value(), + CHECK_NOTNULL(channel)) { + CHECK_NOTNULL(txring_->GetPacketPool()); + // Initialize PCB with random ISN + pcb_.Initialize(GenerateISN()); + } + + ~TcpFlow() = default; + + bool operator==(const TcpFlow& other) const { return key_ == other.key(); } + + const Key& key() const { return key_; } + shm::Channel* channel() const { return channel_; } + tcp::TcpState state() const { return state_; } + + std::string ToString() const { + return utils::Format( + "%s [%s] <-> [%s]\n\t\t\t%s\n\t\t\t[TX Queue] Pending MsgBufs: %u", + key_.ToString().c_str(), tcp::TcpStateToString(state_), + channel_->GetName().c_str(), pcb_.ToString().c_str(), + tx_tracking_.NumUnsentMsgbufs()); + } + + bool Match(const dpdk::Packet* packet) const { + const auto* ih = packet->head_data(sizeof(Ethernet)); + const auto* tcph = + packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + + return (ih->src_addr == key_.remote_addr && + ih->dst_addr == key_.local_addr && + tcph->src_port == key_.remote_port && + tcph->dst_port == key_.local_port); + } + + bool Match(const shm::MsgBuf* tx_msgbuf) const { + const auto* flow_info = tx_msgbuf->flow(); + return (flow_info->src_ip == key_.local_addr.address.value() && + flow_info->dst_ip == key_.remote_addr.address.value() && + 
flow_info->src_port == key_.local_port.port.value() && + flow_info->dst_port == key_.remote_port.port.value()); + } + + /** + * @brief Initiate TCP three-way handshake (active open). + */ + void InitiateHandshake() { + CHECK(state_ == tcp::TcpState::kClosed); + SendSyn(); + state_ = tcp::TcpState::kSynSent; + pcb_.RtoEnable(); + } + + /** + * @brief Shutdown the connection. + */ + void ShutDown() { + switch (state_) { + case tcp::TcpState::kClosed: + break; + case tcp::TcpState::kEstablished: + SendFin(); + state_ = tcp::TcpState::kFinWait1; + break; + case tcp::TcpState::kCloseWait: + SendFin(); + state_ = tcp::TcpState::kLastAck; + break; + default: + // Send RST and close + SendRst(); + state_ = tcp::TcpState::kClosed; + break; + } + pcb_.RtoDisable(); + } + + /** + * @brief Process an incoming TCP packet. + * @param packet Pointer to the received packet. + */ + void InputPacket(const dpdk::Packet* packet) { + const auto* tcph = + packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + const uint8_t tcp_flags = tcph->GetFlags(); + const uint32_t seq = tcph->seq_num.value(); + const uint32_t ack = tcph->ack_num.value(); + const uint16_t window = tcph->window.value(); + + // Update send window + pcb_.UpdateSndWnd(window); + + // Handle RST + if (tcp_flags & Tcp::kRst) { + HandleRst(seq); + return; + } + + switch (state_) { + case tcp::TcpState::kSynSent: + HandleSynSent(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kSynReceived: + HandleSynReceived(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kEstablished: + HandleEstablished(packet, tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kFinWait1: + HandleFinWait1(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kFinWait2: + HandleFinWait2(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kCloseWait: + HandleCloseWait(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kClosing: + HandleClosing(tcph, tcp_flags, seq, ack); + break; + case 
tcp::TcpState::kLastAck: + HandleLastAck(tcph, tcp_flags, seq, ack); + break; + case tcp::TcpState::kTimeWait: + // In TIME_WAIT, respond to any valid segment with ACK + if (tcp_flags & Tcp::kAck) SendAck(); + break; + default: + break; + } + } + + /** + * @brief Push a message from the application to the egress queue. + * @param msg Pointer to the first message buffer. + */ + void OutputMessage(shm::MsgBuf* msg) { + tx_tracking_.Append(msg); + TransmitPackets(); + } + + /** + * @brief Periodic check for timeouts and retransmissions. + * @return True if flow should continue, false if should be removed. + */ + bool PeriodicCheck() { + if (state_ == tcp::TcpState::kClosed) return false; + + if (pcb_.RtoDisabled()) return true; + + pcb_.RtoAdvance(); + if (pcb_.MaxRexmitsReached()) { + if (state_ == tcp::TcpState::kSynSent) { + LOG(INFO) << "TCP Flow " << this << " failed to establish"; + callback_(channel(), false, key()); + } + return false; + } + + if (pcb_.RtoExpired()) { + RTORetransmit(); + } + + // Check for fast retransmit + if (pcb_.ShouldFastRetransmit()) { + FastRetransmit(); + } + + return true; + } + + private: + /** + * @brief Generate Initial Sequence Number (ISN). 
+ */ + static uint32_t GenerateISN() { + // Simple ISN generation - in production use RFC 6528 + return static_cast( + std::chrono::steady_clock::now().time_since_epoch().count() & 0xFFFFFFFF); + } + + void PrepareL2Header(dpdk::Packet* packet) const { + auto* eh = packet->head_data(); + eh->src_addr = local_l2_addr_; + eh->dst_addr = remote_l2_addr_; + eh->eth_type = be16_t(Ethernet::kIpv4); + packet->set_l2_len(sizeof(*eh)); + } + + void PrepareL3Header(dpdk::Packet* packet) const { + auto* ipv4h = packet->head_data(sizeof(Ethernet)); + ipv4h->version_ihl = 0x45; + ipv4h->type_of_service = 0; + ipv4h->packet_id = be16_t(0x1513); + ipv4h->fragment_offset = be16_t(0); + ipv4h->time_to_live = 64; + ipv4h->next_proto_id = Ipv4::Proto::kTcp; + ipv4h->total_length = be16_t(packet->length() - sizeof(Ethernet)); + ipv4h->src_addr = key_.local_addr; + ipv4h->dst_addr = key_.remote_addr; + ipv4h->hdr_checksum = 0; + packet->set_l3_len(sizeof(*ipv4h)); + } + + void PrepareL4Header(dpdk::Packet* packet, uint32_t seq, uint32_t ack, + uint8_t flags) const { + auto* tcph = packet->head_data(sizeof(Ethernet) + sizeof(Ipv4)); + tcph->src_port = key_.local_port; + tcph->dst_port = key_.remote_port; + tcph->seq_num = be32_t(seq); + tcph->ack_num = be32_t(ack); + tcph->SetDataOffset(5); // 20 bytes, no options + tcph->SetFlags(flags); + tcph->window = be16_t(pcb_.rcv_wnd); + tcph->cksum = be16_t(0); + tcph->urgent_ptr = be16_t(0); + packet->offload_tcpv4_csum(); + } + + void SendControlPacket(uint32_t seq, uint32_t ack, uint8_t flags) const { + auto* packet = CHECK_NOTNULL(txring_->GetPacketPool()->PacketAlloc()); + dpdk::Packet::Reset(packet); + + const size_t kControlPacketSize = + sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp); + CHECK_NOTNULL(packet->append(kControlPacketSize)); + PrepareL2Header(packet); + PrepareL3Header(packet); + PrepareL4Header(packet, seq, ack, flags); + + txring_->SendPackets(&packet, 1); + } + + void SendSyn() { + uint32_t seq = 
pcb_.GetAndAdvanceSndNxt(); + SendControlPacket(seq, 0, Tcp::kSyn); + } + + void SendSynAck() { + uint32_t seq = pcb_.GetAndAdvanceSndNxt(); + SendControlPacket(seq, pcb_.rcv_nxt, Tcp::kSyn | Tcp::kAck); + } + + void SendAck() const { + SendControlPacket(pcb_.snd_nxt, pcb_.rcv_nxt, Tcp::kAck); + } + + void SendFin() { + uint32_t seq = pcb_.GetAndAdvanceSndNxt(); + SendControlPacket(seq, pcb_.rcv_nxt, Tcp::kFin | Tcp::kAck); + } + + void SendRst() const { + SendControlPacket(pcb_.snd_nxt, pcb_.rcv_nxt, Tcp::kRst); + } + + void HandleRst(uint32_t seq) { + if (tcp::seqno_eq(seq, pcb_.rcv_nxt)) { + state_ = tcp::TcpState::kClosed; + callback_(channel(), false, key()); + } + } + + void HandleSynSent(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if ((flags & (Tcp::kSyn | Tcp::kAck)) == (Tcp::kSyn | Tcp::kAck)) { + // SYN-ACK received + if (ack != pcb_.snd_nxt) { + LOG(ERROR) << "Invalid SYN-ACK ack number"; + SendRst(); + return; + } + pcb_.snd_una = ack; + pcb_.rcv_nxt = seq + 1; + pcb_.irs = seq; + SendAck(); + state_ = tcp::TcpState::kEstablished; + pcb_.RtoReset(); + callback_(channel(), true, key()); + } else if (flags & Tcp::kSyn) { + // Simultaneous open - SYN received + pcb_.rcv_nxt = seq + 1; + pcb_.irs = seq; + SendSynAck(); + state_ = tcp::TcpState::kSynReceived; + } + } + + void HandleSynReceived(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kAck) { + if (ack == pcb_.snd_nxt) { + pcb_.snd_una = ack; + state_ = tcp::TcpState::kEstablished; + pcb_.RtoReset(); + callback_(channel(), true, key()); + } + } + } + + void HandleEstablished(const dpdk::Packet* packet, const Tcp* tcph, + uint8_t flags, uint32_t seq, uint32_t ack) { + // Process ACK + if (flags & Tcp::kAck) { + bool is_dup = tcp::seqno_le(ack, pcb_.snd_una); + pcb_.ProcessAck(ack, is_dup); + } + + // Process data + size_t hdr_len = sizeof(Ethernet) + sizeof(Ipv4) + tcph->GetHeaderLength(); + size_t payload_len = packet->length() - hdr_len; + if 
(payload_len > 0) { + if (tcp::seqno_eq(seq, pcb_.rcv_nxt)) { + rx_tracking_.Consume(&pcb_, packet, hdr_len, payload_len); + SendAck(); + } else if (tcp::seqno_gt(seq, pcb_.rcv_nxt)) { + // Out of order, buffer and send duplicate ACK + rx_tracking_.Consume(&pcb_, packet, hdr_len, payload_len); + SendAck(); + } + } + + // Handle FIN + if (flags & Tcp::kFin) { + pcb_.rcv_nxt = seq + 1; + SendAck(); + state_ = tcp::TcpState::kCloseWait; + } + + // Try to send more data + TransmitPackets(); + } + + void HandleFinWait1(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kAck) { + pcb_.ProcessAck(ack, false); + if (flags & Tcp::kFin) { + pcb_.rcv_nxt = seq + 1; + SendAck(); + state_ = tcp::TcpState::kTimeWait; + } else { + state_ = tcp::TcpState::kFinWait2; + } + } else if (flags & Tcp::kFin) { + pcb_.rcv_nxt = seq + 1; + SendAck(); + state_ = tcp::TcpState::kClosing; + } + } + + void HandleFinWait2(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kFin) { + pcb_.rcv_nxt = seq + 1; + SendAck(); + state_ = tcp::TcpState::kTimeWait; + } + } + + void HandleCloseWait(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kAck) { + pcb_.ProcessAck(ack, false); + } + } + + void HandleClosing(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kAck) { + pcb_.ProcessAck(ack, false); + state_ = tcp::TcpState::kTimeWait; + } + } + + void HandleLastAck(const Tcp* tcph, uint8_t flags, uint32_t seq, + uint32_t ack) { + if (flags & Tcp::kAck) { + state_ = tcp::TcpState::kClosed; + } + } + + void FastRetransmit() { + auto* msgbuf = tx_tracking_.GetOldestUnackedMsgBuf(); + if (msgbuf == nullptr) return; + + auto* packet = CHECK_NOTNULL(txring_->GetPacketPool()->PacketAlloc()); + PrepareDataPacket(msgbuf, packet, pcb_.snd_una); + txring_->SendPackets(&packet, 1); + pcb_.RtoReset(); + LOG(INFO) << "Fast retransmitting TCP packet " << pcb_.snd_una; + } + + void 
RTORetransmit() { + pcb_.OnTimeout(); + + if (state_ == tcp::TcpState::kEstablished) { + auto* msgbuf = tx_tracking_.GetOldestUnackedMsgBuf(); + if (msgbuf != nullptr) { + auto* packet = CHECK_NOTNULL(txring_->GetPacketPool()->PacketAlloc()); + PrepareDataPacket(msgbuf, packet, pcb_.snd_una); + txring_->SendPackets(&packet, 1); + LOG(INFO) << "RTO retransmitting TCP data packet " << pcb_.snd_una; + } + } else if (state_ == tcp::TcpState::kSynSent) { + SendSyn(); + LOG(INFO) << "RTO retransmitting SYN packet"; + } else if (state_ == tcp::TcpState::kSynReceived) { + SendSynAck(); + LOG(INFO) << "RTO retransmitting SYN-ACK packet"; + } + pcb_.RtoReset(); + } + + void PrepareDataPacket(shm::MsgBuf* msg_buf, dpdk::Packet* packet, + uint32_t seqno) const { + const size_t hdr_length = sizeof(Ethernet) + sizeof(Ipv4) + sizeof(Tcp); + const uint32_t pkt_len = hdr_length + msg_buf->length(); + CHECK_LE(pkt_len - sizeof(Ethernet), dpdk::PmdRing::kDefaultFrameSize); + + dpdk::Packet::Reset(packet); + CHECK_NOTNULL(packet->append(pkt_len)); + + PrepareL2Header(packet); + PrepareL3Header(packet); + PrepareL4Header(packet, seqno, pcb_.rcv_nxt, Tcp::kAck | Tcp::kPsh); + + // Copy payload + auto* payload = + packet->head_data(sizeof(Ethernet) + sizeof(Ipv4) + + sizeof(Tcp)); + utils::Copy(payload, msg_buf->head_data(), msg_buf->length()); + } + + void TransmitPackets() { + auto effective_wnd = pcb_.EffectiveWindow() / pcb_.mss; + auto remaining_packets = + std::min(effective_wnd, tx_tracking_.NumUnsentMsgbufs()); + if (remaining_packets == 0) return; + + do { + dpdk::PacketBatch batch; + auto pkt_cnt = + std::min(remaining_packets, static_cast(batch.GetRoom())); + if (!txring_->GetPacketPool()->PacketBulkAlloc(&batch, pkt_cnt)) { + LOG(ERROR) << "Failed to allocate packet batch"; + return; + } + + for (uint16_t i = 0; i < batch.GetSize(); i++) { + auto msg = tx_tracking_.GetAndUpdateOldestUnsent(); + if (!msg.has_value()) break; + auto* msg_buf = msg.value(); + auto* packet = 
batch.pkts()[i]; + PrepareDataPacket(msg_buf, packet, pcb_.GetAndAdvanceSndNxt(msg_buf->length())); + } + + txring_->SendPackets(&batch); + remaining_packets -= pkt_cnt; + } while (remaining_packets); + + if (pcb_.RtoDisabled()) pcb_.RtoEnable(); + } + + const Key key_; + const Ethernet::Address local_l2_addr_; + const Ethernet::Address remote_l2_addr_; + tcp::TcpState state_; + dpdk::TxRing* txring_; + ApplicationCallback callback_; + shm::Channel* channel_; + tcp::TcpControlBlock pcb_; + TXTracking tx_tracking_; + RXTracking rx_tracking_; +}; + +} // namespace tcp_flow +} // namespace net +} // namespace juggler + +namespace std { + +template <> +struct hash { + size_t operator()(const juggler::net::tcp_flow::TcpFlow& flow) const { + const auto& key = flow.key(); + return juggler::utils::hash(reinterpret_cast(&key), + sizeof(key)); + } +}; + +} // namespace std + +#endif // SRC_INCLUDE_TCP_FLOW_H_