Skip to content

Commit

Permalink
Update UDP: Ensure propagation of control message to pool.Message
Browse files Browse the repository at this point in the history
This commit enhances the UDP functionality, ensuring proper dissemination
of control messages to pool.Message for improved network coordination
and responsiveness
  • Loading branch information
jkralik committed Nov 12, 2023
1 parent d63d5c8 commit 03ea641
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 35 deletions.
2 changes: 1 addition & 1 deletion dtls/server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (s *Session) Run(cc *client.Conn) (err error) {
return fmt.Errorf("cannot read from connection: %w", err)
}
readBuf = readBuf[:readLen]
err = cc.Process(readBuf)
err = cc.Process(nil, readBuf)
if err != nil {
return err
}
Expand Down
20 changes: 20 additions & 0 deletions message/pool/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
multierror "github.com/hashicorp/go-multierror"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/net"
"go.uber.org/atomic"
)

Expand All @@ -26,6 +27,7 @@ type Message struct {
// Context context of request.
ctx context.Context
msg message.Message
controlMessage *net.ControlMessage // control message for UDP
hijacked atomic.Bool
isModified bool
valueBuffer []byte
Expand Down Expand Up @@ -73,6 +75,22 @@ func (r *Message) SetMessage(message message.Message) {
r.isModified = true
}

func (r *Message) SetControlMessage(cm *net.ControlMessage) {
r.controlMessage = cm
}

func (r *Message) ControlMessage() *net.ControlMessage {
return r.controlMessage
}

// UpsertControlMessage set value only when origin value is not set.
func (r *Message) UpsertControlMessage(cm *net.ControlMessage) {
if r.controlMessage != nil {
return
}
r.SetControlMessage(cm)
}

// SetMessageID only 0 to 2^16-1 are valid.
func (r *Message) SetMessageID(mid int32) {
r.msg.MessageID = mid
Expand Down Expand Up @@ -120,6 +138,7 @@ func (r *Message) Reset() {
r.valueBuffer = r.origValueBuffer
r.body = nil
r.isModified = false
r.controlMessage = nil
if cap(r.bufferMarshal) > 1024 {
r.bufferMarshal = make([]byte, 256)
}
Expand Down Expand Up @@ -568,6 +587,7 @@ func (r *Message) Clone(msg *Message) error {
msg.ResetOptionsTo(r.Options())
msg.SetType(r.Type())
msg.SetMessageID(r.MessageID())
msg.SetControlMessage(r.ControlMessage())

if r.Body() != nil {
buf := bytes.NewBuffer(nil)
Expand Down
115 changes: 92 additions & 23 deletions net/connUDP.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,36 @@ type UDPConn struct {
}

type ControlMessage struct {
Dst net.IP // destination address, receiving only
Src net.IP // source address, specifying only
IfIndex int // interface index, must be 1 <= value when specifying
}

func (c *ControlMessage) String() string {
if c == nil {
return ""
}
var sb strings.Builder
if c.Dst != nil {
sb.WriteString(fmt.Sprintf("Dst: %s, ", c.Dst))
}
if c.Src != nil {
sb.WriteString(fmt.Sprintf("Src: %s, ", c.Src))
}
if c.IfIndex >= 1 {
sb.WriteString(fmt.Sprintf("IfIndex: %d, ", c.IfIndex))
}
return sb.String()
}

// GetIfIndex returns the interface index of the network interface. 0 means no interface index specified.
func (c *ControlMessage) GetIfIndex() int {
if c == nil {
return 0
}
return c.IfIndex
}

type packetConn interface {
SetWriteDeadline(t time.Time) error
WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error)
Expand All @@ -36,14 +62,18 @@ type packetConn interface {
SetMulticastLoopback(on bool) error
JoinGroup(ifi *net.Interface, group net.Addr) error
LeaveGroup(ifi *net.Interface, group net.Addr) error
ReadFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error)
}

type packetConnIPv4 struct {
packetConnIPv4 *ipv4.PacketConn
}

func newPacketConnIPv4(p *ipv4.PacketConn) *packetConnIPv4 {
return &packetConnIPv4{p}
func newPacketConnIPv4(p *ipv4.PacketConn) (*packetConnIPv4, error) {
if err := p.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface|ipv4.FlagSrc, true); err != nil {
return nil, err
}
return &packetConnIPv4{p}, nil
}

func (p *packetConnIPv4) SetMulticastInterface(ifi *net.Interface) error {
Expand All @@ -65,6 +95,18 @@ func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
return p.packetConnIPv4.WriteTo(b, c, dst)
}

func (p *packetConnIPv4) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) {
n, cm, src, err := p.packetConnIPv4.ReadFrom(b)
if err != nil {
return -1, nil, nil, err
}
return n, &ControlMessage{
Dst: cm.Dst,
Src: cm.Src,
IfIndex: cm.IfIndex,
}, src, err
}

func (p *packetConnIPv4) SetMulticastHopLimit(hoplim int) error {
return p.packetConnIPv4.SetMulticastTTL(hoplim)
}
Expand All @@ -85,8 +127,11 @@ type packetConnIPv6 struct {
packetConnIPv6 *ipv6.PacketConn
}

func newPacketConnIPv6(p *ipv6.PacketConn) *packetConnIPv6 {
return &packetConnIPv6{p}
func newPacketConnIPv6(p *ipv6.PacketConn) (*packetConnIPv6, error) {
if err := p.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface|ipv6.FlagSrc, true); err != nil {
return nil, err
}
return &packetConnIPv6{p}, nil
}

func (p *packetConnIPv6) SetMulticastInterface(ifi *net.Interface) error {
Expand All @@ -97,6 +142,18 @@ func (p *packetConnIPv6) SetWriteDeadline(t time.Time) error {
return p.packetConnIPv6.SetWriteDeadline(t)
}

func (p *packetConnIPv6) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) {
n, cm, src, err := p.packetConnIPv6.ReadFrom(b)
if err != nil {
return -1, nil, nil, err
}
return n, &ControlMessage{
Dst: cm.Dst,
Src: cm.Src,
IfIndex: cm.IfIndex,
}, src, err
}

func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) {
var c *ipv6.ControlMessage
if cm != nil {
Expand Down Expand Up @@ -124,10 +181,6 @@ func (p *packetConnIPv6) LeaveGroup(ifi *net.Interface, group net.Addr) error {
return p.packetConnIPv6.LeaveGroup(ifi, group)
}

func (p *packetConnIPv6) SetControlMessage(on bool) error {
return p.packetConnIPv6.SetMulticastLoopback(on)
}

// IsIPv6 return's true if addr is IPV6.
func IsIPv6(addr net.IP) bool {
if ip := addr.To16(); ip != nil && ip.To4() == nil {
Expand Down Expand Up @@ -174,10 +227,17 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn {
panic(fmt.Errorf("invalid address type(%T), UDP address expected", laddr))
}
var pc packetConn
var err error
if IsIPv6(addr.IP) {
pc = newPacketConnIPv6(ipv6.NewPacketConn(c))
pc, err = newPacketConnIPv6(ipv6.NewPacketConn(c))
if err != nil {
panic(fmt.Errorf("invalid UDP connection: %w", err))
}
} else {
pc = newPacketConnIPv4(ipv4.NewPacketConn(c))
pc, err = newPacketConnIPv4(ipv4.NewPacketConn(c))
if err != nil {
panic(fmt.Errorf("invalid UDP connection: %w", err))
}
}

return &UDPConn{
Expand Down Expand Up @@ -214,11 +274,18 @@ func (c *UDPConn) Close() error {
func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error {
var pktSrc net.IP
var p packetConn
var err error
if IsIPv6(raddr.IP) {
p = newPacketConnIPv6(ipv6.NewPacketConn(c.connection))
p, err = newPacketConnIPv6(ipv6.NewPacketConn(c.connection))
if err != nil {
return err
}
pktSrc = net.IPv6zero
} else {
p = newPacketConnIPv4(ipv4.NewPacketConn(c.connection))
p, err = newPacketConnIPv4(ipv4.NewPacketConn(c.connection))
if err != nil {
return err
}
pktSrc = net.IPv4zero
}
if src != nil {
Expand All @@ -229,15 +296,14 @@ func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLim
return ErrConnectionIsClosed
}
if iface != nil {
if err := p.SetMulticastInterface(iface); err != nil {
if err = p.SetMulticastInterface(iface); err != nil {
return err
}
}
if err := p.SetMulticastHopLimit(multicastHopLimit); err != nil {
if err = p.SetMulticastHopLimit(multicastHopLimit); err != nil {
return err
}

var err error
if iface != nil || src != nil {
_, err = p.WriteTo(buffer, &ControlMessage{
Src: pktSrc,
Expand Down Expand Up @@ -409,7 +475,7 @@ func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer
}

// WriteWithContext writes data with context.
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error {
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) error {
if raddr == nil {
return fmt.Errorf("cannot write with context: invalid raddr")
}
Expand All @@ -422,7 +488,7 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
if c.closed.Load() {
return ErrConnectionIsClosed
}
n, err := WriteToUDP(c.connection, raddr, buffer)
n, err := c.packetConn.WriteTo(buffer, cm, raddr)
if err != nil {
return err
}
Expand All @@ -434,20 +500,23 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
}

// ReadWithContext reads packet with context.
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) {
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *ControlMessage, *net.UDPAddr, error) {
select {
case <-ctx.Done():
return -1, nil, ctx.Err()
return -1, nil, nil, ctx.Err()
default:
}
if c.closed.Load() {
return -1, nil, ErrConnectionIsClosed
return -1, nil, nil, ErrConnectionIsClosed
}
n, s, err := c.connection.ReadFromUDP(buffer)
n, cm, srcAddr, err := c.packetConn.ReadFrom(buffer)
if err != nil {
return -1, nil, fmt.Errorf("cannot read from udp connection: %w", err)
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: %w", err)
}
if udpAdrr, ok := srcAddr.(*net.UDPAddr); ok {
return n, cm, udpAdrr, nil
}
return n, s, err
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr)
}

// SetMulticastLoopback sets whether transmitted multicast packets
Expand Down
6 changes: 3 additions & 3 deletions net/connUDP_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ func TestUDPConnWriteWithContext(t *testing.T) {

go func() {
b := make([]byte, 1024)
_, _, errR := c2.ReadWithContext(ctx, b)
_, _, _, errR := c2.ReadWithContext(ctx, b)
if errR != nil {
return
}
}()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, tt.args.buffer)
err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, nil, tt.args.buffer)

c1.LocalAddr()
c1.RemoteAddr()
Expand Down Expand Up @@ -201,7 +201,7 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) {
wg.Add(1)
go func() {
b := make([]byte, 1024)
n, _, errR := c2.ReadWithContext(ctx, b)
n, _, _, errR := c2.ReadWithContext(ctx, b)
assert.NoError(t, errR)
if n > 0 {
b = b[:n]
Expand Down
15 changes: 14 additions & 1 deletion udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
}()
resp := cc.AcquireMessage(cc.Context())
resp.SetToken(req.Token())
ifIndex := req.ControlMessage().GetIfIndex()
w := responsewriter.New(resp, cc, req.Options()...)
defer func() {
cc.ReleaseMessage(w.Message())
Expand All @@ -730,6 +731,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
// nothing to send
return
}
upsertInterfaceToMessage(w.Message(), ifIndex)
errW := cc.writeMessageAsync(w.Message())
if errW != nil {
cc.closeConnection()
Expand All @@ -741,6 +743,15 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess
cc.sendPong(w, r)
}

func upsertInterfaceToMessage(m *pool.Message, ifIndex int) {
if ifIndex >= 1 {
cm := coapNet.ControlMessage{
IfIndex: ifIndex,
}
m.UpsertControlMessage(&cm)
}
}

func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
// ping request
if r.Code() == codes.Empty && r.Type() == message.Confirmable && len(r.Token()) == 0 && len(r.Options()) == 0 && r.Body() == nil {
Expand All @@ -752,6 +763,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
elem.ReleaseMessage(cc)
resp := cc.AcquireMessage(cc.Context())
resp.SetToken(r.Token())
upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex())
w := responsewriter.New(resp, cc, r.Options()...)
defer func() {
cc.ReleaseMessage(w.Message())
Expand All @@ -769,7 +781,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
return false
}

func (cc *Conn) Process(datagram []byte) error {
func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error {
if uint32(len(datagram)) > cc.session.MaxMessageSize() {
return fmt.Errorf("max message size(%v) was exceeded %v", cc.session.MaxMessageSize(), len(datagram))
}
Expand All @@ -779,6 +791,7 @@ func (cc *Conn) Process(datagram []byte) error {
cc.ReleaseMessage(req)
return err
}
req.SetControlMessage(cm)
req.SetSequence(cc.Sequence())
cc.checkMyMessageID(req)
cc.inactivityMonitor.Notify()
Expand Down
2 changes: 1 addition & 1 deletion udp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,7 @@ func TestClientKeepAliveMonitor(t *testing.T) {
go func() {
defer serverWg.Done()
for {
_, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024))
_, _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024))
if errR != nil {
if errors.Is(errR, net.ErrClosed) {
return
Expand Down
2 changes: 1 addition & 1 deletion udp/server/discover.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (s *Server) DiscoveryRequest(req *pool.Message, address string, receiverFun
return err
}
} else {
err = c.WriteWithContext(req.Context(), addr, data)
err = c.WriteWithContext(req.Context(), addr, nil, data)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit 03ea641

Please sign in to comment.