Skip to content

Commit

Permalink
Merge pull request #68 from wolfendale/timeout-rework
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbarrow authored Jul 2, 2024
2 parents 2b6ea02 + 0502cc6 commit 293c20d
Show file tree
Hide file tree
Showing 12 changed files with 344 additions and 188 deletions.
16 changes: 16 additions & 0 deletions mutex_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,22 @@ func (m *MutexMap[K, V]) Get(key K) (V, bool) {
return value, ok
}

// GetOrSetDefault returns the value for the given key if it exists. If it does not exist, it creates a default
// with the provided function and sets it for that key
func (m *MutexMap[K, V]) GetOrSetDefault(key K, createDefault func() V) V {
m.Lock()
defer m.Unlock()

value, ok := m.real[key]

if !ok {
value = createDefault()
m.real[key] = value
}

return value
}

// Has checks if a key exists in the map
func (m *MutexMap[K, V]) Has(key K) bool {
m.RLock()
Expand Down
6 changes: 5 additions & 1 deletion prudp_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ type PRUDPConnection struct {
Signature []byte // * Connection signature for packets coming from the client, as seen by the server
ServerConnectionSignature []byte // * Connection signature for packets coming from the server, as seen by the client
UnreliablePacketBaseKey []byte // * The base key used for encrypting unreliable DATA packets
rtt *RTT // * The round-trip transmission time of this connection
slidingWindows *MutexMap[uint8, *SlidingWindow] // * Outbound reliable packet substreams
packetDispatchQueues *MutexMap[uint8, *PacketDispatchQueue] // * Inbound reliable packet substreams
incomingFragmentBuffers *MutexMap[uint8, []byte] // * Buffers which store the incoming payloads from fragmented DATA packets
outgoingUnreliableSequenceIDCounter *Counter[uint16]
outgoingPingSequenceIDCounter *Counter[uint16]
lastSentPingTime time.Time
heartbeatTimer *time.Timer
pingKickTimer *time.Timer
StationURLs *types.List[*types.StationURL]
Expand Down Expand Up @@ -62,12 +64,13 @@ func (pc *PRUDPConnection) SetPID(pid *types.PID) {

// reset resets the connection state to all zero values
func (pc *PRUDPConnection) reset() {
pc.ConnectionState = StateNotConnected
pc.packetDispatchQueues.Clear(func(_ uint8, packetDispatchQueue *PacketDispatchQueue) {
packetDispatchQueue.Purge()
})

pc.slidingWindows.Clear(func(_ uint8, slidingWindow *SlidingWindow) {
slidingWindow.ResendScheduler.Stop()
slidingWindow.TimeoutManager.Stop()
})

pc.Signature = make([]byte, 0)
Expand Down Expand Up @@ -289,6 +292,7 @@ func NewPRUDPConnection(socket *SocketConnection) *PRUDPConnection {
pc := &PRUDPConnection{
Socket: socket,
ConnectionState: StateNotConnected,
rtt: NewRTT(),
pid: types.NewPID(0),
slidingWindows: NewMutexMap[uint8, *SlidingWindow](),
packetDispatchQueues: NewMutexMap[uint8, *PacketDispatchQueue](),
Expand Down
97 changes: 70 additions & 27 deletions prudp_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,25 @@ import (
// and secure servers. However the functionality of rdv::PRUDPEndPoint and nn::nex::SecureEndPoint is seemingly
// identical. Rather than duplicate the logic from PRUDPEndpoint, a IsSecureEndpoint flag has been added instead.
type PRUDPEndPoint struct {
Server *PRUDPServer
StreamID uint8
DefaultStreamSettings *StreamSettings
Connections *MutexMap[string, *PRUDPConnection]
packetHandlers map[uint16]func(packet PRUDPPacketInterface)
packetEventHandlers map[string][]func(packet PacketInterface)
connectionEndedEventHandlers []func(connection *PRUDPConnection)
errorEventHandlers []func(err *Error)
ConnectionIDCounter *Counter[uint32]
ServerAccount *Account
AccountDetailsByPID func(pid *types.PID) (*Account, *Error)
AccountDetailsByUsername func(username string) (*Account, *Error)
IsSecureEndPoint bool
}
Server *PRUDPServer
StreamID uint8
DefaultStreamSettings *StreamSettings
Connections *MutexMap[string, *PRUDPConnection]
packetHandlers map[uint16]func(packet PRUDPPacketInterface)
packetEventHandlers map[string][]func(packet PacketInterface)
connectionEndedEventHandlers []func(connection *PRUDPConnection)
errorEventHandlers []func(err *Error)
ConnectionIDCounter *Counter[uint32]
ServerAccount *Account
AccountDetailsByPID func(pid *types.PID) (*Account, *Error)
AccountDetailsByUsername func(username string) (*Account, *Error)
IsSecureEndPoint bool
CalcRetransmissionTimeoutCallback CalcRetransmissionTimeoutCallback
}

// CalcRetransmissionTimeoutCallback is an optional callback which can be used to override the RTO calculation
// for packets sent by this `PRUDPEndpoint`
type CalcRetransmissionTimeoutCallback func(rtt float64, sendCount uint32) time.Duration

// RegisterServiceProtocol registers a NEX service with the endpoint
func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) {
Expand Down Expand Up @@ -111,19 +116,19 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc
streamType := packet.SourceVirtualPortStreamType()
streamID := packet.SourceVirtualPortStreamID()
discriminator := fmt.Sprintf("%s-%d-%d", socket.Address.String(), streamType, streamID)
connection, ok := pep.Connections.Get(discriminator)

if !ok {
connection = NewPRUDPConnection(socket)
connection := pep.Connections.GetOrSetDefault(discriminator, func() *PRUDPConnection {
connection := NewPRUDPConnection(socket)
connection.endpoint = pep
connection.ID = pep.ConnectionIDCounter.Next()
connection.DefaultPRUDPVersion = packet.Version()
connection.StreamType = streamType
connection.StreamID = streamID
connection.StreamSettings = pep.DefaultStreamSettings.Copy()
return connection
})

pep.Connections.Set(discriminator, connection)
}
connection.Lock()
defer connection.Unlock()

packet.SetSender(connection)

Expand Down Expand Up @@ -153,8 +158,14 @@ func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) {
return
}

slidingWindow := connection.SlidingWindow(packet.SubstreamID())
slidingWindow.ResendScheduler.AcknowledgePacket(packet.SequenceID())
if packet.Type() == constants.PingPacket {
if packet.SequenceID() == connection.outgoingPingSequenceIDCounter.Value {
connection.rtt.Adjust(time.Since(connection.lastSentPingTime))
}
} else {
slidingWindow := connection.SlidingWindow(packet.SubstreamID())
slidingWindow.TimeoutManager.AcknowledgePacket(packet.SequenceID())
}
}

func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) {
Expand Down Expand Up @@ -191,7 +202,7 @@ func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface)

// * MutexMap.Each locks the mutex, can't remove while reading.
// * Have to just loop again
slidingWindow.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) bool {
slidingWindow.TimeoutManager.packets.Each(func(sequenceID uint16, pending PRUDPPacketInterface) bool {
if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) {
sequenceIDs = append(sequenceIDs, sequenceID)
}
Expand All @@ -201,7 +212,7 @@ func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface)

// * Actually remove the packets from the pool
for _, sequenceID := range sequenceIDs {
slidingWindow.ResendScheduler.AcknowledgePacket(sequenceID)
slidingWindow.TimeoutManager.AcknowledgePacket(sequenceID)
}
}

Expand Down Expand Up @@ -397,7 +408,6 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) {

func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) {
// TODO - Should we check the state here, or just let the connection disconnect at any time?
// TODO - Should we bother to set the connections state here? It's being destroyed anyway

if packet.HasFlag(constants.PacketFlagNeedsAck) {
pep.acknowledgePacket(packet)
Expand All @@ -407,6 +417,8 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) {
streamID := packet.SourceVirtualPortStreamID()
discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID)
if connection, ok := pep.Connections.Get(discriminator); ok {
// * We make sure to update the connection state here because we could still be attempting to
// * resend packets.
connection.cleanup()
pep.Connections.Delete(discriminator)
}
Expand Down Expand Up @@ -539,8 +551,6 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) {
}

connection := packet.Sender().(*PRUDPConnection)
connection.Lock()
defer connection.Unlock()

substreamID := packet.SubstreamID()

Expand Down Expand Up @@ -702,6 +712,39 @@ func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection {
return connection
}

// ComputeRetransmitTimeout computes the RTO (Retransmit timeout) for a given packet
func (pep *PRUDPEndPoint) ComputeRetransmitTimeout(packet PRUDPPacketInterface) time.Duration {
connection := packet.Sender().(*PRUDPConnection)
rtt := connection.rtt

if callback := pep.CalcRetransmissionTimeoutCallback; callback != nil {
rttAverage := rtt.GetRTTSmoothedAvg()
rttDeviation := rtt.GetRTTSmoothedDev()
return callback(rttAverage+rttDeviation*4.0, packet.SendCount())
}

var retransmitTimeBase int64
if packet.Type() == constants.SynPacket {
retransmitTimeBase = int64(pep.DefaultStreamSettings.SynInitialRTT)
} else {
retransmitTimeBase = int64(pep.DefaultStreamSettings.InitialRTT)
if rtt.Initialized() {
retransmitTimeBase = int64(rtt.Average()/time.Millisecond) / 8
}
}

retransmitTimeBaseMultiplier := packet.SendCount()

var retransmitMultiplier float64
if packet.SendCount() < pep.DefaultStreamSettings.ExtraRetransmitTimeoutTrigger {
retransmitMultiplier = float64(pep.DefaultStreamSettings.RetransmitTimeoutMultiplier)
} else {
retransmitMultiplier = float64(pep.DefaultStreamSettings.ExtraRetransmitTimeoutMultiplier)
}

return time.Duration(float64(retransmitTimeBase*int64(retransmitTimeBaseMultiplier))*retransmitMultiplier) * time.Millisecond
}

// AccessKey returns the servers sandbox access key
func (pep *PRUDPEndPoint) AccessKey() string {
return pep.Server.AccessKey
Expand Down
30 changes: 30 additions & 0 deletions prudp_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nex

import (
"crypto/rc4"
"time"

"github.com/PretendoNetwork/nex-go/v2/constants"
)
Expand All @@ -24,6 +25,9 @@ type PRUDPPacket struct {
fragmentID uint8
payload []byte
message *RMCMessage
sendCount uint32
sentAt time.Time
timeout *Timeout
}

// SetSender sets the Client who sent the packet
Expand Down Expand Up @@ -184,6 +188,32 @@ func (p *PRUDPPacket) SetRMCMessage(message *RMCMessage) {
p.message = message
}

// SendCount returns the number of times this packet has been sent
func (p *PRUDPPacket) SendCount() uint32 {
return p.sendCount
}

func (p *PRUDPPacket) incrementSendCount() {
p.sendCount++
}

// SentAt returns the latest time that this packet has been sent
func (p *PRUDPPacket) SentAt() time.Time {
return p.sentAt
}

func (p *PRUDPPacket) setSentAt(time time.Time) {
p.sentAt = time
}

func (p *PRUDPPacket) getTimeout() *Timeout {
return p.timeout
}

func (p *PRUDPPacket) setTimeout(timeout *Timeout) {
p.timeout = timeout
}

func (p *PRUDPPacket) processUnreliableCrypto() []byte {
// * Since unreliable DATA packets can come in out of
// * order, each packet uses a dedicated RC4 stream
Expand Down
7 changes: 7 additions & 0 deletions prudp_packet_interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package nex

import (
"net"
"time"

"github.com/PretendoNetwork/nex-go/v2/constants"
)
Expand Down Expand Up @@ -36,6 +37,12 @@ type PRUDPPacketInterface interface {
SetPayload(payload []byte)
RMCMessage() *RMCMessage
SetRMCMessage(message *RMCMessage)
SendCount() uint32
incrementSendCount()
SentAt() time.Time
setSentAt(time time.Time)
getTimeout() *Timeout
setTimeout(timeout *Timeout)
decode() error
setSignature(signature []byte)
calculateConnectionSignature(addr net.Addr) ([]byte, error)
Expand Down
6 changes: 5 additions & 1 deletion prudp_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) {
packetCopy.SetSequenceID(connection.outgoingUnreliableSequenceIDCounter.Next())
} else if packetCopy.Type() == constants.PingPacket {
packetCopy.SetSequenceID(connection.outgoingPingSequenceIDCounter.Next())
connection.lastSentPingTime = time.Now()
} else {
packetCopy.SetSequenceID(0)
}
Expand Down Expand Up @@ -288,9 +289,12 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) {
packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature))
}

packetCopy.incrementSendCount()
packetCopy.setSentAt(time.Now())

if packetCopy.HasFlag(constants.PacketFlagReliable) && packetCopy.HasFlag(constants.PacketFlagNeedsAck) {
slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID())
slidingWindow.ResendScheduler.AddPacket(packetCopy)
slidingWindow.TimeoutManager.SchedulePacketTimeout(packetCopy)
}

ps.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes())
Expand Down
Loading

0 comments on commit 293c20d

Please sign in to comment.