Skip to content

Commit 9719626

Browse files
committed
Update UDP: Ensure propagation of control message to pool.Message
This commit enhances the UDP functionality, ensuring proper dissemination of control messages to pool.Message for improved network coordination and responsiveness
1 parent d63d5c8 commit 9719626

File tree

9 files changed

+321
-52
lines changed

9 files changed

+321
-52
lines changed

dtls/server/session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func (s *Session) Run(cc *client.Conn) (err error) {
146146
return fmt.Errorf("cannot read from connection: %w", err)
147147
}
148148
readBuf = readBuf[:readLen]
149-
err = cc.Process(readBuf)
149+
err = cc.Process(nil, readBuf)
150150
if err != nil {
151151
return err
152152
}

message/pool/message.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
multierror "github.com/hashicorp/go-multierror"
1111
"github.com/plgd-dev/go-coap/v3/message"
1212
"github.com/plgd-dev/go-coap/v3/message/codes"
13+
"github.com/plgd-dev/go-coap/v3/net"
1314
"go.uber.org/atomic"
1415
)
1516

@@ -26,6 +27,7 @@ type Message struct {
2627
// Context context of request.
2728
ctx context.Context
2829
msg message.Message
30+
controlMessage *net.ControlMessage // control message for UDP
2931
hijacked atomic.Bool
3032
isModified bool
3133
valueBuffer []byte
@@ -73,6 +75,22 @@ func (r *Message) SetMessage(message message.Message) {
7375
r.isModified = true
7476
}
7577

78+
func (r *Message) SetControlMessage(cm *net.ControlMessage) {
79+
r.controlMessage = cm
80+
}
81+
82+
func (r *Message) ControlMessage() *net.ControlMessage {
83+
return r.controlMessage
84+
}
85+
86+
// UpsertControlMessage set value only when origin value is not set.
87+
func (r *Message) UpsertControlMessage(cm *net.ControlMessage) {
88+
if r.controlMessage != nil {
89+
return
90+
}
91+
r.SetControlMessage(cm)
92+
}
93+
7694
// SetMessageID only 0 to 2^16-1 are valid.
7795
func (r *Message) SetMessageID(mid int32) {
7896
r.msg.MessageID = mid
@@ -120,6 +138,7 @@ func (r *Message) Reset() {
120138
r.valueBuffer = r.origValueBuffer
121139
r.body = nil
122140
r.isModified = false
141+
r.controlMessage = nil
123142
if cap(r.bufferMarshal) > 1024 {
124143
r.bufferMarshal = make([]byte, 256)
125144
}
@@ -568,6 +587,7 @@ func (r *Message) Clone(msg *Message) error {
568587
msg.ResetOptionsTo(r.Options())
569588
msg.SetType(r.Type())
570589
msg.SetMessageID(r.MessageID())
590+
msg.SetControlMessage(r.ControlMessage())
571591

572592
if r.Body() != nil {
573593
buf := bytes.NewBuffer(nil)

net/connUDP.go

Lines changed: 133 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,36 @@ type UDPConn struct {
2424
}
2525

2626
type ControlMessage struct {
27+
Dst net.IP // destination address, receiving only
2728
Src net.IP // source address, specifying only
2829
IfIndex int // interface index, must be 1 <= value when specifying
2930
}
3031

32+
func (c *ControlMessage) String() string {
33+
if c == nil {
34+
return ""
35+
}
36+
var sb strings.Builder
37+
if c.Dst != nil {
38+
sb.WriteString(fmt.Sprintf("Dst: %s, ", c.Dst))
39+
}
40+
if c.Src != nil {
41+
sb.WriteString(fmt.Sprintf("Src: %s, ", c.Src))
42+
}
43+
if c.IfIndex >= 1 {
44+
sb.WriteString(fmt.Sprintf("IfIndex: %d, ", c.IfIndex))
45+
}
46+
return sb.String()
47+
}
48+
49+
// GetIfIndex returns the interface index of the network interface. 0 means no interface index specified.
50+
func (c *ControlMessage) GetIfIndex() int {
51+
if c == nil {
52+
return 0
53+
}
54+
return c.IfIndex
55+
}
56+
3157
type packetConn interface {
3258
SetWriteDeadline(t time.Time) error
3359
WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error)
@@ -36,22 +62,34 @@ type packetConn interface {
3662
SetMulticastLoopback(on bool) error
3763
JoinGroup(ifi *net.Interface, group net.Addr) error
3864
LeaveGroup(ifi *net.Interface, group net.Addr) error
65+
ReadFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error)
3966
}
4067

4168
type packetConnIPv4 struct {
42-
packetConnIPv4 *ipv4.PacketConn
69+
packetConn *ipv4.PacketConn
70+
controlMessageNotSupported bool
4371
}
4472

45-
func newPacketConnIPv4(p *ipv4.PacketConn) *packetConnIPv4 {
46-
return &packetConnIPv4{p}
73+
func isNotImplemented(err error) bool {
74+
return strings.Contains(err.Error(), "not implemented on")
75+
}
76+
77+
func newPacketConnIPv4(p *ipv4.PacketConn) (*packetConnIPv4, error) {
78+
if err := p.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface|ipv4.FlagSrc, true); err != nil {
79+
if isNotImplemented(err) {
80+
return &packetConnIPv4{packetConn: p, controlMessageNotSupported: true}, nil
81+
}
82+
return nil, err
83+
}
84+
return &packetConnIPv4{packetConn: p, controlMessageNotSupported: false}, nil
4785
}
4886

4987
func (p *packetConnIPv4) SetMulticastInterface(ifi *net.Interface) error {
50-
return p.packetConnIPv4.SetMulticastInterface(ifi)
88+
return p.packetConn.SetMulticastInterface(ifi)
5189
}
5290

5391
func (p *packetConnIPv4) SetWriteDeadline(t time.Time) error {
54-
return p.packetConnIPv4.SetWriteDeadline(t)
92+
return p.packetConn.SetWriteDeadline(t)
5593
}
5694

5795
func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) {
@@ -62,39 +100,78 @@ func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
62100
IfIndex: cm.IfIndex,
63101
}
64102
}
65-
return p.packetConnIPv4.WriteTo(b, c, dst)
103+
return p.packetConn.WriteTo(b, c, dst)
104+
}
105+
106+
func (p *packetConnIPv4) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) {
107+
n, cm, src, err := p.packetConn.ReadFrom(b)
108+
if err != nil {
109+
return -1, nil, nil, err
110+
}
111+
var controlMessage *ControlMessage
112+
if p.controlMessageNotSupported && cm != nil {
113+
controlMessage = &ControlMessage{
114+
Dst: cm.Dst,
115+
Src: cm.Src,
116+
IfIndex: cm.IfIndex,
117+
}
118+
}
119+
return n, controlMessage, src, err
66120
}
67121

68122
func (p *packetConnIPv4) SetMulticastHopLimit(hoplim int) error {
69-
return p.packetConnIPv4.SetMulticastTTL(hoplim)
123+
return p.packetConn.SetMulticastTTL(hoplim)
70124
}
71125

72126
func (p *packetConnIPv4) SetMulticastLoopback(on bool) error {
73-
return p.packetConnIPv4.SetMulticastLoopback(on)
127+
return p.packetConn.SetMulticastLoopback(on)
74128
}
75129

76130
func (p *packetConnIPv4) JoinGroup(ifi *net.Interface, group net.Addr) error {
77-
return p.packetConnIPv4.JoinGroup(ifi, group)
131+
return p.packetConn.JoinGroup(ifi, group)
78132
}
79133

80134
func (p *packetConnIPv4) LeaveGroup(ifi *net.Interface, group net.Addr) error {
81-
return p.packetConnIPv4.LeaveGroup(ifi, group)
135+
return p.packetConn.LeaveGroup(ifi, group)
82136
}
83137

84138
type packetConnIPv6 struct {
85-
packetConnIPv6 *ipv6.PacketConn
139+
packetConn *ipv6.PacketConn
140+
controlMessageNotSupported bool
86141
}
87142

88-
func newPacketConnIPv6(p *ipv6.PacketConn) *packetConnIPv6 {
89-
return &packetConnIPv6{p}
143+
func newPacketConnIPv6(p *ipv6.PacketConn) (*packetConnIPv6, error) {
144+
if err := p.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface|ipv6.FlagSrc, true); err != nil {
145+
if isNotImplemented(err) {
146+
return &packetConnIPv6{packetConn: p, controlMessageNotSupported: true}, nil
147+
}
148+
return nil, err
149+
}
150+
return &packetConnIPv6{packetConn: p}, nil
90151
}
91152

92153
func (p *packetConnIPv6) SetMulticastInterface(ifi *net.Interface) error {
93-
return p.packetConnIPv6.SetMulticastInterface(ifi)
154+
return p.packetConn.SetMulticastInterface(ifi)
94155
}
95156

96157
func (p *packetConnIPv6) SetWriteDeadline(t time.Time) error {
97-
return p.packetConnIPv6.SetWriteDeadline(t)
158+
return p.packetConn.SetWriteDeadline(t)
159+
}
160+
161+
func (p *packetConnIPv6) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) {
162+
n, cm, src, err := p.packetConn.ReadFrom(b)
163+
if err != nil {
164+
return -1, nil, nil, err
165+
}
166+
var controlMessage *ControlMessage
167+
if p.controlMessageNotSupported && cm != nil {
168+
controlMessage = &ControlMessage{
169+
Dst: cm.Dst,
170+
Src: cm.Src,
171+
IfIndex: cm.IfIndex,
172+
}
173+
}
174+
return n, controlMessage, src, err
98175
}
99176

100177
func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) {
@@ -105,27 +182,23 @@ func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
105182
IfIndex: cm.IfIndex,
106183
}
107184
}
108-
return p.packetConnIPv6.WriteTo(b, c, dst)
185+
return p.packetConn.WriteTo(b, c, dst)
109186
}
110187

111188
func (p *packetConnIPv6) SetMulticastHopLimit(hoplim int) error {
112-
return p.packetConnIPv6.SetMulticastHopLimit(hoplim)
189+
return p.packetConn.SetMulticastHopLimit(hoplim)
113190
}
114191

115192
func (p *packetConnIPv6) SetMulticastLoopback(on bool) error {
116-
return p.packetConnIPv6.SetMulticastLoopback(on)
193+
return p.packetConn.SetMulticastLoopback(on)
117194
}
118195

119196
func (p *packetConnIPv6) JoinGroup(ifi *net.Interface, group net.Addr) error {
120-
return p.packetConnIPv6.JoinGroup(ifi, group)
197+
return p.packetConn.JoinGroup(ifi, group)
121198
}
122199

123200
func (p *packetConnIPv6) LeaveGroup(ifi *net.Interface, group net.Addr) error {
124-
return p.packetConnIPv6.LeaveGroup(ifi, group)
125-
}
126-
127-
func (p *packetConnIPv6) SetControlMessage(on bool) error {
128-
return p.packetConnIPv6.SetMulticastLoopback(on)
201+
return p.packetConn.LeaveGroup(ifi, group)
129202
}
130203

131204
// IsIPv6 return's true if addr is IPV6.
@@ -174,10 +247,17 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn {
174247
panic(fmt.Errorf("invalid address type(%T), UDP address expected", laddr))
175248
}
176249
var pc packetConn
250+
var err error
177251
if IsIPv6(addr.IP) {
178-
pc = newPacketConnIPv6(ipv6.NewPacketConn(c))
252+
pc, err = newPacketConnIPv6(ipv6.NewPacketConn(c))
253+
if err != nil {
254+
panic(fmt.Errorf("invalid UDP connection: %w", err))
255+
}
179256
} else {
180-
pc = newPacketConnIPv4(ipv4.NewPacketConn(c))
257+
pc, err = newPacketConnIPv4(ipv4.NewPacketConn(c))
258+
if err != nil {
259+
panic(fmt.Errorf("invalid UDP connection: %w", err))
260+
}
181261
}
182262

183263
return &UDPConn{
@@ -214,11 +294,18 @@ func (c *UDPConn) Close() error {
214294
func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error {
215295
var pktSrc net.IP
216296
var p packetConn
297+
var err error
217298
if IsIPv6(raddr.IP) {
218-
p = newPacketConnIPv6(ipv6.NewPacketConn(c.connection))
299+
p, err = newPacketConnIPv6(ipv6.NewPacketConn(c.connection))
300+
if err != nil {
301+
return err
302+
}
219303
pktSrc = net.IPv6zero
220304
} else {
221-
p = newPacketConnIPv4(ipv4.NewPacketConn(c.connection))
305+
p, err = newPacketConnIPv4(ipv4.NewPacketConn(c.connection))
306+
if err != nil {
307+
return err
308+
}
222309
pktSrc = net.IPv4zero
223310
}
224311
if src != nil {
@@ -229,19 +316,22 @@ func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLim
229316
return ErrConnectionIsClosed
230317
}
231318
if iface != nil {
232-
if err := p.SetMulticastInterface(iface); err != nil {
319+
if err = p.SetMulticastInterface(iface); err != nil {
233320
return err
234321
}
235322
}
236-
if err := p.SetMulticastHopLimit(multicastHopLimit); err != nil {
323+
if err = p.SetMulticastHopLimit(multicastHopLimit); err != nil {
237324
return err
238325
}
239326

240-
var err error
241327
if iface != nil || src != nil {
328+
ifaceIdx := 0
329+
if iface != nil {
330+
ifaceIdx = iface.Index
331+
}
242332
_, err = p.WriteTo(buffer, &ControlMessage{
243333
Src: pktSrc,
244-
IfIndex: iface.Index,
334+
IfIndex: ifaceIdx,
245335
}, raddr)
246336
} else {
247337
_, err = p.WriteTo(buffer, nil, raddr)
@@ -409,7 +499,7 @@ func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer
409499
}
410500

411501
// WriteWithContext writes data with context.
412-
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error {
502+
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) error {
413503
if raddr == nil {
414504
return fmt.Errorf("cannot write with context: invalid raddr")
415505
}
@@ -422,7 +512,7 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
422512
if c.closed.Load() {
423513
return ErrConnectionIsClosed
424514
}
425-
n, err := WriteToUDP(c.connection, raddr, buffer)
515+
n, err := c.packetConn.WriteTo(buffer, cm, raddr)
426516
if err != nil {
427517
return err
428518
}
@@ -434,20 +524,23 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
434524
}
435525

436526
// ReadWithContext reads packet with context.
437-
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) {
527+
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *ControlMessage, *net.UDPAddr, error) {
438528
select {
439529
case <-ctx.Done():
440-
return -1, nil, ctx.Err()
530+
return -1, nil, nil, ctx.Err()
441531
default:
442532
}
443533
if c.closed.Load() {
444-
return -1, nil, ErrConnectionIsClosed
534+
return -1, nil, nil, ErrConnectionIsClosed
445535
}
446-
n, s, err := c.connection.ReadFromUDP(buffer)
536+
n, cm, srcAddr, err := c.packetConn.ReadFrom(buffer)
447537
if err != nil {
448-
return -1, nil, fmt.Errorf("cannot read from udp connection: %w", err)
538+
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: %w", err)
539+
}
540+
if udpAdrr, ok := srcAddr.(*net.UDPAddr); ok {
541+
return n, cm, udpAdrr, nil
449542
}
450-
return n, s, err
543+
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr)
451544
}
452545

453546
// SetMulticastLoopback sets whether transmitted multicast packets

0 commit comments

Comments
 (0)