Skip to content

Commit 4f6861b

Browse files
authored
Merge pull request #413 from Hyperloop-UPV/backend/transport_bench
Backend, Transport Module Refactor
2 parents e91961d + 23918ae commit 4f6861b

File tree

7 files changed

+604
-62
lines changed

7 files changed

+604
-62
lines changed

backend/pkg/transport/constructor.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212

1313
func NewTransport(baseLogger zerolog.Logger) *Transport {
1414
transport := &Transport{
15-
connectionsMx: &sync.Mutex{},
15+
connectionsMx: &sync.RWMutex{},
1616
connections: make(map[abstraction.TransportTarget]net.Conn),
1717
idToTarget: make(map[abstraction.PacketId]abstraction.TransportTarget),
1818
ipToTarget: make(map[string]abstraction.TransportTarget),

backend/pkg/transport/packet/data/decoder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ func (decoder *Decoder) Decode(id abstraction.PacketId, reader io.Reader) (abstr
3535
return nil, ErrUnexpectedId{Id: id}
3636
}
3737

38-
packet := NewPacket(id)
38+
packet := GetPacket(id)
3939
for _, value := range descriptor {
4040
val, err := value.Decode(decoder.endianness, reader)
4141
if err != nil {

backend/pkg/transport/packet/data/packet.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package data
22

33
import (
4+
"sync"
45
"time"
56

67
"github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction"
@@ -27,6 +28,16 @@ func NewPacket(id abstraction.PacketId) *Packet {
2728
}
2829
}
2930

31+
var packetPool = sync.Pool{
32+
New: func() any {
33+
return &Packet{
34+
values: make(map[ValueName]Value),
35+
enabled: make(map[ValueName]bool),
36+
}
37+
},
38+
}
39+
40+
3041
// NewPacketWithValues creates a new data packet with the given values
3142
func NewPacketWithValues(id abstraction.PacketId, values map[ValueName]Value, enabled map[ValueName]bool) *Packet {
3243
return &Packet{
@@ -62,3 +73,35 @@ func (packet *Packet) SetTimestamp(timestamp time.Time) *Packet {
6273
packet.timestamp = timestamp
6374
return packet
6475
}
76+
77+
func (packet *Packet) Reset() {
78+
clear(packet.values)
79+
clear(packet.enabled)
80+
packet.id = 0
81+
packet.timestamp = time.Time{}
82+
}
83+
84+
func GetPacket(id abstraction.PacketId) *Packet {
85+
p := packetPool.Get().(*Packet)
86+
if p.values == nil {
87+
p.values = make(map[ValueName]Value)
88+
} else {
89+
clear(p.values)
90+
}
91+
if p.enabled == nil {
92+
p.enabled = make(map[ValueName]bool)
93+
} else {
94+
clear(p.enabled)
95+
}
96+
p.id = id
97+
p.timestamp = time.Now()
98+
return p
99+
}
100+
101+
func ReleasePacket(p *Packet) {
102+
if p == nil {
103+
return
104+
}
105+
p.Reset()
106+
packetPool.Put(p)
107+
}

backend/pkg/transport/presentation/encoder.go

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/binary"
66
"io"
7+
"sync"
78

89
"github.com/HyperloopUPV-H8/h9-backend/pkg/abstraction"
910
"github.com/rs/zerolog"
@@ -17,7 +18,8 @@ type Encoder struct {
1718
idToEncoder map[abstraction.PacketId]PacketEncoder
1819
endianness binary.ByteOrder
1920

20-
logger zerolog.Logger
21+
logger zerolog.Logger
22+
bufPool sync.Pool
2123
}
2224

2325
// TODO: improve constructor
@@ -28,6 +30,9 @@ func NewEncoder(endianness binary.ByteOrder, baseLogger zerolog.Logger) *Encoder
2830
endianness: endianness,
2931

3032
logger: baseLogger,
33+
bufPool: sync.Pool{
34+
New: func() any { return new(bytes.Buffer) },
35+
},
3136
}
3237
}
3338

@@ -37,23 +42,41 @@ func (encoder *Encoder) SetPacketEncoder(id abstraction.PacketId, enc PacketEnco
3742
encoder.logger.Trace().Uint16("id", uint16(id)).Type("encoder", enc).Msg("set encoder")
3843
}
3944

40-
// Encode encodes the provided packet into a byte slice, returning any errors
41-
func (encoder *Encoder) Encode(packet abstraction.Packet) ([]byte, error) {
45+
// Encode encodes the provided packet into a pooled buffer. Callers must release
46+
// the buffer via ReleaseBuffer once they are done using the returned data.
47+
func (encoder *Encoder) Encode(packet abstraction.Packet) (*bytes.Buffer, error) {
4248
enc, ok := encoder.idToEncoder[packet.Id()]
4349
if !ok {
4450
encoder.logger.Warn().Uint16("id", uint16(packet.Id())).Msg("no encoder set")
4551
return nil, ErrUnexpectedId{Id: packet.Id()}
4652
}
4753

48-
buffer := new(bytes.Buffer)
54+
bufferAny := encoder.bufPool.Get()
55+
buffer := bufferAny.(*bytes.Buffer)
56+
buffer.Reset()
4957

5058
err := binary.Write(buffer, encoder.endianness, packet.Id())
5159
if err != nil {
5260
encoder.logger.Error().Stack().Err(err).Uint16("id", uint16(packet.Id())).Msg("buffering id")
53-
return buffer.Bytes(), err
61+
encoder.ReleaseBuffer(buffer)
62+
return nil, err
5463
}
5564

5665
encoder.logger.Debug().Uint16("id", uint16(packet.Id())).Type("encoder", enc).Msg("encoding")
5766
err = enc.Encode(packet, buffer)
58-
return buffer.Bytes(), err
67+
if err != nil {
68+
encoder.ReleaseBuffer(buffer)
69+
return nil, err
70+
}
71+
72+
return buffer, nil
73+
}
74+
75+
// ReleaseBuffer returns a buffer obtained from Encode back to the pool.
76+
func (encoder *Encoder) ReleaseBuffer(buffer *bytes.Buffer) {
77+
if buffer == nil {
78+
return
79+
}
80+
buffer.Reset()
81+
encoder.bufPool.Put(buffer)
5982
}

backend/pkg/transport/presentation/encoder_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,13 @@ func TestEncoder(t *testing.T) {
379379

380380
output := make([]byte, 0, len(test.output))
381381
for i := 0; i < len(test.input); i++ {
382-
encoded, err := encoder.Encode(test.input[i])
382+
buf, err := encoder.Encode(test.input[i])
383383
if err != nil {
384384
t.Fatalf("\nError encoding (%d) packet: %s\n", i+1, err)
385385
}
386386

387-
output = append(output, encoded...)
387+
output = append(output, buf.Bytes()...)
388+
encoder.ReleaseBuffer(buf)
388389

389390
}
390391

backend/pkg/transport/transport.go

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ type Transport struct {
3131
decoder *presentation.Decoder
3232
encoder *presentation.Encoder
3333

34-
connectionsMx *sync.Mutex
34+
connectionsMx *sync.RWMutex
3535
connections map[abstraction.TransportTarget]net.Conn
3636

3737
ipToTarget map[string]abstraction.TransportTarget
@@ -45,22 +45,27 @@ type Transport struct {
4545

4646
logger zerolog.Logger
4747

48+
byteReaderPool sync.Pool
49+
4850
errChan chan error
4951
}
5052

53+
// For tests
54+
var zeroTime time.Time
55+
5156
// HandleClient connects to the specified client and handles its messages. This method blocks.
5257
// This method will continuously try to reconnect to the client if it disconnects,
5358
// applying exponential backoff between attempts.
5459
func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string) error {
5560
client := tcp.NewClient(remote, config, transport.logger)
56-
defer transport.logger.Warn().Str("remoteAddress", remote).Msg("abort connection")
61+
clientLogger := transport.logger.With().Str("remoteAddress", remote).Logger()
62+
defer clientLogger.Warn().Msg("abort connection")
5763
var hasConnected = false
5864

5965
for {
6066
conn, err := client.Dial()
6167
if err != nil {
62-
transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("dial failed")
63-
68+
clientLogger.Debug().Stack().Err(err).Msg("dial failed")
6469
// Only return if reconnection is disabled
6570
if !config.TryReconnect {
6671
if hasConnected {
@@ -73,7 +78,7 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string)
7378
// For ErrTooManyRetries, we still want to continue retrying
7479
// The client will reset its retry counter on the next Dial() call
7580
if _, ok := err.(tcp.ErrTooManyRetries); ok {
76-
transport.logger.Warn().Str("remoteAddress", remote).Msg("reached max retries, will continue attempting to reconnect")
81+
clientLogger.Warn().Msg("reached max retries, will continue attempting to reconnect")
7782
// Add a longer delay before restarting the retry cycle
7883
time.Sleep(config.ConnectionBackoffFunction(config.MaxConnectionRetries))
7984
}
@@ -85,12 +90,12 @@ func (transport *Transport) HandleClient(config tcp.ClientConfig, remote string)
8590

8691
err = transport.handleTCPConn(conn)
8792
if errors.Is(err, error(ErrTargetAlreadyConnected{})) {
88-
transport.logger.Warn().Stack().Err(err).Str("remoteAddress", remote).Msg("multiple connections for same target")
93+
clientLogger.Warn().Stack().Err(err).Msg("multiple connections for same target")
8994
transport.errChan <- err
9095
return err
9196
}
9297
if err != nil {
93-
transport.logger.Debug().Stack().Err(err).Str("remoteAddress", remote).Msg("connection lost")
98+
clientLogger.Debug().Stack().Err(err).Msg("connection lost")
9499
if !config.TryReconnect {
95100
transport.SendFault()
96101
transport.errChan <- err
@@ -254,6 +259,10 @@ func (transport *Transport) readLoopTCPConn(conn net.Conn, logger zerolog.Logger
254259

255260
logger.Trace().Type("type", packet).Msg("packet")
256261
transport.api.Notification(NewPacketNotification(packet, from, to, time.Now()))
262+
263+
if dataPacket, ok := packet.(*data.Packet); ok {
264+
data.ReleasePacket(dataPacket)
265+
}
257266
}
258267
}()
259268
}
@@ -289,30 +298,31 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {
289298

290299
if message.Id() == 0 {
291300
eventLogger.Info().Msg("broadcasting packet id 0")
292-
data, err := transport.encoder.Encode(message.Packet)
301+
buf, err := transport.encoder.Encode(message.Packet)
293302
if err != nil {
294303
eventLogger.Error().Stack().Err(err).Msg("encode")
295304
transport.errChan <- err
296305
return err
297306
}
307+
defer transport.encoder.ReleaseBuffer(buf)
308+
data := buf.Bytes()
298309

299-
transport.connectionsMx.Lock()
300-
defer transport.connectionsMx.Unlock()
310+
transport.connectionsMx.RLock()
311+
defer transport.connectionsMx.RUnlock()
301312
for target, conn := range transport.connections {
302-
eventLogger := eventLogger.With().Str("target", string(target)).Logger()
303-
313+
targetName := string(target)
304314
totalWritten := 0
305315
for totalWritten < len(data) {
306316
n, err := conn.Write(data[totalWritten:])
307-
eventLogger.Trace().Int("amount", n).Msg("written chunk")
317+
eventLogger.Trace().Str("target", targetName).Int("amount", n).Msg("written chunk")
308318
totalWritten += n
309319
if err != nil {
310-
eventLogger.Error().Stack().Err(err).Msg("write")
320+
eventLogger.Error().Str("target", targetName).Stack().Err(err).Msg("write")
311321
transport.errChan <- err
312322
return err
313323
}
314324
}
315-
eventLogger.Info().Msg("sent")
325+
eventLogger.Info().Str("target", targetName).Msg("sent")
316326
}
317327
return nil
318328
}
@@ -328,11 +338,11 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {
328338
eventLogger.Info().Msg("sending")
329339

330340
conn, err := func() (net.Conn, error) {
331-
transport.connectionsMx.Lock()
332-
defer transport.connectionsMx.Unlock()
341+
transport.connectionsMx.RLock()
342+
defer transport.connectionsMx.RUnlock()
333343
conn, ok := transport.connections[target]
334344
if !ok {
335-
eventLogger.Warn().Msg("target not connected")
345+
eventLogger.Warn().Msg("target not connected")
336346

337347
err := ErrConnClosed{Target: target}
338348
return nil, err
@@ -344,12 +354,14 @@ func (transport *Transport) handlePacketEvent(message PacketMessage) error {
344354
return err
345355
}
346356

347-
data, err := transport.encoder.Encode(message.Packet)
357+
buf, err := transport.encoder.Encode(message.Packet)
348358
if err != nil {
349359
eventLogger.Error().Stack().Err(err).Msg("encode")
350360
transport.errChan <- err
351361
return err
352362
}
363+
defer transport.encoder.ReleaseBuffer(buf)
364+
data := buf.Bytes()
353365

354366
totalWritten := 0
355367
for totalWritten < len(data) {
@@ -413,14 +425,30 @@ func (transport *Transport) HandleUDPServer(server *udp.Server) {
413425
}
414426
}
415427

428+
func (transport *Transport) replicateFault(packet abstraction.Packet, logger zerolog.Logger) {
429+
logger.Info().Msg("replicating packet with id 0 to all boards")
430+
err := transport.handlePacketEvent(NewPacketMessage(packet))
431+
if err != nil {
432+
logger.Error().Err(err).Msg("failed to replicate packet")
433+
}
434+
}
435+
416436
// handleUDPPacket handles a single UDP packet received by the UDP server
417437
func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) {
418438
srcAddr := fmt.Sprintf("%s:%d", udpPacket.SourceIP, udpPacket.SourcePort)
419439
dstAddr := fmt.Sprintf("%s:%d", udpPacket.DestIP, udpPacket.DestPort)
420440

421441
// Create a reader from the payload
422-
reader := bytes.NewReader(udpPacket.Payload)
423-
442+
readerAny := transport.byteReaderPool.Get()
443+
var reader *bytes.Reader
444+
if readerAny != nil {
445+
reader = readerAny.(*bytes.Reader)
446+
reader.Reset(udpPacket.Payload)
447+
} else {
448+
reader = bytes.NewReader(udpPacket.Payload)
449+
}
450+
defer transport.byteReaderPool.Put(reader)
451+
424452
// Decode the packet
425453
packet, err := transport.decoder.DecodeNext(reader)
426454
if err != nil {
@@ -435,15 +463,15 @@ func (transport *Transport) handleUDPPacket(udpPacket udp.Packet) {
435463

436464
// Intercept packets with id == 0 and replicate
437465
if transport.propagateFault && packet.Id() == 0 {
438-
transport.logger.Info().Msg("replicating packet with id 0 to all boards")
439-
err := transport.handlePacketEvent(NewPacketMessage(packet))
440-
if err != nil {
441-
transport.logger.Error().Err(err).Msg("failed to replicate packet")
442-
}
466+
transport.replicateFault(packet, transport.logger)
443467
}
444468

445469
// Send notification
446470
transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, udpPacket.Timestamp))
471+
472+
if dataPacket, ok := packet.(*data.Packet); ok {
473+
data.ReleasePacket(dataPacket)
474+
}
447475
}
448476

449477
// handleConversation is called when the sniffer detects a new conversation and handles its specific packets
@@ -463,14 +491,15 @@ func (transport *Transport) handleConversation(socket network.Socket, reader io.
463491

464492
// Intercept packets with id == 0 and replicate
465493
if transport.propagateFault && packet.Id() == 0 {
466-
conversationLogger.Info().Msg("replicating packet with id 0 to all boards")
467-
err := transport.handlePacketEvent(NewPacketMessage(packet))
468-
if err != nil {
469-
conversationLogger.Error().Err(err).Msg("failed to replicate packet")
470-
}
494+
transport.replicateFault(packet, transport.logger)
471495
}
472496

497+
// Send notification
473498
transport.api.Notification(NewPacketNotification(packet, srcAddr, dstAddr, time.Now()))
499+
500+
if dataPacket, ok := packet.(*data.Packet); ok {
501+
data.ReleasePacket(dataPacket)
502+
}
474503
}
475504
}()
476505
}

0 commit comments

Comments
 (0)