@@ -24,10 +24,36 @@ type UDPConn struct {
24
24
}
25
25
26
26
type ControlMessage struct {
27
+ Dst net.IP // destination address, receiving only
27
28
Src net.IP // source address, specifying only
28
29
IfIndex int // interface index, must be 1 <= value when specifying
29
30
}
30
31
32
+ func (c * ControlMessage ) String () string {
33
+ if c == nil {
34
+ return ""
35
+ }
36
+ var sb strings.Builder
37
+ if c .Dst != nil {
38
+ sb .WriteString (fmt .Sprintf ("Dst: %s, " , c .Dst ))
39
+ }
40
+ if c .Src != nil {
41
+ sb .WriteString (fmt .Sprintf ("Src: %s, " , c .Src ))
42
+ }
43
+ if c .IfIndex >= 1 {
44
+ sb .WriteString (fmt .Sprintf ("IfIndex: %d, " , c .IfIndex ))
45
+ }
46
+ return sb .String ()
47
+ }
48
+
49
+ // GetIfIndex returns the interface index of the network interface. 0 means no interface index specified.
50
+ func (c * ControlMessage ) GetIfIndex () int {
51
+ if c == nil {
52
+ return 0
53
+ }
54
+ return c .IfIndex
55
+ }
56
+
31
57
type packetConn interface {
32
58
SetWriteDeadline (t time.Time ) error
33
59
WriteTo (b []byte , cm * ControlMessage , dst net.Addr ) (n int , err error )
@@ -36,22 +62,34 @@ type packetConn interface {
36
62
SetMulticastLoopback (on bool ) error
37
63
JoinGroup (ifi * net.Interface , group net.Addr ) error
38
64
LeaveGroup (ifi * net.Interface , group net.Addr ) error
65
+ ReadFrom (b []byte ) (n int , cm * ControlMessage , src net.Addr , err error )
39
66
}
40
67
41
68
type packetConnIPv4 struct {
42
- packetConnIPv4 * ipv4.PacketConn
69
+ packetConn * ipv4.PacketConn
70
+ controlMessageNotSupported bool
43
71
}
44
72
45
- func newPacketConnIPv4 (p * ipv4.PacketConn ) * packetConnIPv4 {
46
- return & packetConnIPv4 {p }
73
+ func isNotImplemented (err error ) bool {
74
+ return strings .Contains (err .Error (), "not implemented on" )
75
+ }
76
+
77
+ func newPacketConnIPv4 (p * ipv4.PacketConn ) (* packetConnIPv4 , error ) {
78
+ if err := p .SetControlMessage (ipv4 .FlagDst | ipv4 .FlagInterface | ipv4 .FlagSrc , true ); err != nil {
79
+ if isNotImplemented (err ) {
80
+ return & packetConnIPv4 {packetConn : p , controlMessageNotSupported : true }, nil
81
+ }
82
+ return nil , err
83
+ }
84
+ return & packetConnIPv4 {packetConn : p , controlMessageNotSupported : false }, nil
47
85
}
48
86
49
87
func (p * packetConnIPv4 ) SetMulticastInterface (ifi * net.Interface ) error {
50
- return p .packetConnIPv4 .SetMulticastInterface (ifi )
88
+ return p .packetConn .SetMulticastInterface (ifi )
51
89
}
52
90
53
91
func (p * packetConnIPv4 ) SetWriteDeadline (t time.Time ) error {
54
- return p .packetConnIPv4 .SetWriteDeadline (t )
92
+ return p .packetConn .SetWriteDeadline (t )
55
93
}
56
94
57
95
func (p * packetConnIPv4 ) WriteTo (b []byte , cm * ControlMessage , dst net.Addr ) (n int , err error ) {
@@ -62,39 +100,78 @@ func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
62
100
IfIndex : cm .IfIndex ,
63
101
}
64
102
}
65
- return p .packetConnIPv4 .WriteTo (b , c , dst )
103
+ return p .packetConn .WriteTo (b , c , dst )
104
+ }
105
+
106
+ func (p * packetConnIPv4 ) ReadFrom (b []byte ) (int , * ControlMessage , net.Addr , error ) {
107
+ n , cm , src , err := p .packetConn .ReadFrom (b )
108
+ if err != nil {
109
+ return - 1 , nil , nil , err
110
+ }
111
+ var controlMessage * ControlMessage
112
+ if p .controlMessageNotSupported && cm != nil {
113
+ controlMessage = & ControlMessage {
114
+ Dst : cm .Dst ,
115
+ Src : cm .Src ,
116
+ IfIndex : cm .IfIndex ,
117
+ }
118
+ }
119
+ return n , controlMessage , src , err
66
120
}
67
121
68
122
func (p * packetConnIPv4 ) SetMulticastHopLimit (hoplim int ) error {
69
- return p .packetConnIPv4 .SetMulticastTTL (hoplim )
123
+ return p .packetConn .SetMulticastTTL (hoplim )
70
124
}
71
125
72
126
func (p * packetConnIPv4 ) SetMulticastLoopback (on bool ) error {
73
- return p .packetConnIPv4 .SetMulticastLoopback (on )
127
+ return p .packetConn .SetMulticastLoopback (on )
74
128
}
75
129
76
130
func (p * packetConnIPv4 ) JoinGroup (ifi * net.Interface , group net.Addr ) error {
77
- return p .packetConnIPv4 .JoinGroup (ifi , group )
131
+ return p .packetConn .JoinGroup (ifi , group )
78
132
}
79
133
80
134
func (p * packetConnIPv4 ) LeaveGroup (ifi * net.Interface , group net.Addr ) error {
81
- return p .packetConnIPv4 .LeaveGroup (ifi , group )
135
+ return p .packetConn .LeaveGroup (ifi , group )
82
136
}
83
137
84
138
type packetConnIPv6 struct {
85
- packetConnIPv6 * ipv6.PacketConn
139
+ packetConn * ipv6.PacketConn
140
+ controlMessageNotSupported bool
86
141
}
87
142
88
- func newPacketConnIPv6 (p * ipv6.PacketConn ) * packetConnIPv6 {
89
- return & packetConnIPv6 {p }
143
+ func newPacketConnIPv6 (p * ipv6.PacketConn ) (* packetConnIPv6 , error ) {
144
+ if err := p .SetControlMessage (ipv6 .FlagDst | ipv6 .FlagInterface | ipv6 .FlagSrc , true ); err != nil {
145
+ if isNotImplemented (err ) {
146
+ return & packetConnIPv6 {packetConn : p , controlMessageNotSupported : true }, nil
147
+ }
148
+ return nil , err
149
+ }
150
+ return & packetConnIPv6 {packetConn : p }, nil
90
151
}
91
152
92
153
func (p * packetConnIPv6 ) SetMulticastInterface (ifi * net.Interface ) error {
93
- return p .packetConnIPv6 .SetMulticastInterface (ifi )
154
+ return p .packetConn .SetMulticastInterface (ifi )
94
155
}
95
156
96
157
func (p * packetConnIPv6 ) SetWriteDeadline (t time.Time ) error {
97
- return p .packetConnIPv6 .SetWriteDeadline (t )
158
+ return p .packetConn .SetWriteDeadline (t )
159
+ }
160
+
161
+ func (p * packetConnIPv6 ) ReadFrom (b []byte ) (int , * ControlMessage , net.Addr , error ) {
162
+ n , cm , src , err := p .packetConn .ReadFrom (b )
163
+ if err != nil {
164
+ return - 1 , nil , nil , err
165
+ }
166
+ var controlMessage * ControlMessage
167
+ if p .controlMessageNotSupported && cm != nil {
168
+ controlMessage = & ControlMessage {
169
+ Dst : cm .Dst ,
170
+ Src : cm .Src ,
171
+ IfIndex : cm .IfIndex ,
172
+ }
173
+ }
174
+ return n , controlMessage , src , err
98
175
}
99
176
100
177
func (p * packetConnIPv6 ) WriteTo (b []byte , cm * ControlMessage , dst net.Addr ) (n int , err error ) {
@@ -105,27 +182,23 @@ func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n
105
182
IfIndex : cm .IfIndex ,
106
183
}
107
184
}
108
- return p .packetConnIPv6 .WriteTo (b , c , dst )
185
+ return p .packetConn .WriteTo (b , c , dst )
109
186
}
110
187
111
188
func (p * packetConnIPv6 ) SetMulticastHopLimit (hoplim int ) error {
112
- return p .packetConnIPv6 .SetMulticastHopLimit (hoplim )
189
+ return p .packetConn .SetMulticastHopLimit (hoplim )
113
190
}
114
191
115
192
func (p * packetConnIPv6 ) SetMulticastLoopback (on bool ) error {
116
- return p .packetConnIPv6 .SetMulticastLoopback (on )
193
+ return p .packetConn .SetMulticastLoopback (on )
117
194
}
118
195
119
196
func (p * packetConnIPv6 ) JoinGroup (ifi * net.Interface , group net.Addr ) error {
120
- return p .packetConnIPv6 .JoinGroup (ifi , group )
197
+ return p .packetConn .JoinGroup (ifi , group )
121
198
}
122
199
123
200
func (p * packetConnIPv6 ) LeaveGroup (ifi * net.Interface , group net.Addr ) error {
124
- return p .packetConnIPv6 .LeaveGroup (ifi , group )
125
- }
126
-
127
- func (p * packetConnIPv6 ) SetControlMessage (on bool ) error {
128
- return p .packetConnIPv6 .SetMulticastLoopback (on )
201
+ return p .packetConn .LeaveGroup (ifi , group )
129
202
}
130
203
131
204
// IsIPv6 return's true if addr is IPV6.
@@ -174,10 +247,17 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn {
174
247
panic (fmt .Errorf ("invalid address type(%T), UDP address expected" , laddr ))
175
248
}
176
249
var pc packetConn
250
+ var err error
177
251
if IsIPv6 (addr .IP ) {
178
- pc = newPacketConnIPv6 (ipv6 .NewPacketConn (c ))
252
+ pc , err = newPacketConnIPv6 (ipv6 .NewPacketConn (c ))
253
+ if err != nil {
254
+ panic (fmt .Errorf ("invalid UDP connection: %w" , err ))
255
+ }
179
256
} else {
180
- pc = newPacketConnIPv4 (ipv4 .NewPacketConn (c ))
257
+ pc , err = newPacketConnIPv4 (ipv4 .NewPacketConn (c ))
258
+ if err != nil {
259
+ panic (fmt .Errorf ("invalid UDP connection: %w" , err ))
260
+ }
181
261
}
182
262
183
263
return & UDPConn {
@@ -214,11 +294,18 @@ func (c *UDPConn) Close() error {
214
294
func (c * UDPConn ) writeToAddr (iface * net.Interface , src * net.IP , multicastHopLimit int , raddr * net.UDPAddr , buffer []byte ) error {
215
295
var pktSrc net.IP
216
296
var p packetConn
297
+ var err error
217
298
if IsIPv6 (raddr .IP ) {
218
- p = newPacketConnIPv6 (ipv6 .NewPacketConn (c .connection ))
299
+ p , err = newPacketConnIPv6 (ipv6 .NewPacketConn (c .connection ))
300
+ if err != nil {
301
+ return err
302
+ }
219
303
pktSrc = net .IPv6zero
220
304
} else {
221
- p = newPacketConnIPv4 (ipv4 .NewPacketConn (c .connection ))
305
+ p , err = newPacketConnIPv4 (ipv4 .NewPacketConn (c .connection ))
306
+ if err != nil {
307
+ return err
308
+ }
222
309
pktSrc = net .IPv4zero
223
310
}
224
311
if src != nil {
@@ -229,19 +316,22 @@ func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLim
229
316
return ErrConnectionIsClosed
230
317
}
231
318
if iface != nil {
232
- if err : = p .SetMulticastInterface (iface ); err != nil {
319
+ if err = p .SetMulticastInterface (iface ); err != nil {
233
320
return err
234
321
}
235
322
}
236
- if err : = p .SetMulticastHopLimit (multicastHopLimit ); err != nil {
323
+ if err = p .SetMulticastHopLimit (multicastHopLimit ); err != nil {
237
324
return err
238
325
}
239
326
240
- var err error
241
327
if iface != nil || src != nil {
328
+ ifaceIdx := 0
329
+ if iface != nil {
330
+ ifaceIdx = iface .Index
331
+ }
242
332
_ , err = p .WriteTo (buffer , & ControlMessage {
243
333
Src : pktSrc ,
244
- IfIndex : iface . Index ,
334
+ IfIndex : ifaceIdx ,
245
335
}, raddr )
246
336
} else {
247
337
_ , err = p .WriteTo (buffer , nil , raddr )
@@ -409,7 +499,7 @@ func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer
409
499
}
410
500
411
501
// WriteWithContext writes data with context.
412
- func (c * UDPConn ) WriteWithContext (ctx context.Context , raddr * net.UDPAddr , buffer []byte ) error {
502
+ func (c * UDPConn ) WriteWithContext (ctx context.Context , raddr * net.UDPAddr , cm * ControlMessage , buffer []byte ) error {
413
503
if raddr == nil {
414
504
return fmt .Errorf ("cannot write with context: invalid raddr" )
415
505
}
@@ -422,7 +512,7 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
422
512
if c .closed .Load () {
423
513
return ErrConnectionIsClosed
424
514
}
425
- n , err := WriteToUDP ( c . connection , raddr , buffer )
515
+ n , err := c . packetConn . WriteTo ( buffer , cm , raddr )
426
516
if err != nil {
427
517
return err
428
518
}
@@ -434,20 +524,23 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff
434
524
}
435
525
436
526
// ReadWithContext reads packet with context.
437
- func (c * UDPConn ) ReadWithContext (ctx context.Context , buffer []byte ) (int , * net.UDPAddr , error ) {
527
+ func (c * UDPConn ) ReadWithContext (ctx context.Context , buffer []byte ) (int , * ControlMessage , * net.UDPAddr , error ) {
438
528
select {
439
529
case <- ctx .Done ():
440
- return - 1 , nil , ctx .Err ()
530
+ return - 1 , nil , nil , ctx .Err ()
441
531
default :
442
532
}
443
533
if c .closed .Load () {
444
- return - 1 , nil , ErrConnectionIsClosed
534
+ return - 1 , nil , nil , ErrConnectionIsClosed
445
535
}
446
- n , s , err := c .connection . ReadFromUDP (buffer )
536
+ n , cm , srcAddr , err := c .packetConn . ReadFrom (buffer )
447
537
if err != nil {
448
- return - 1 , nil , fmt .Errorf ("cannot read from udp connection: %w" , err )
538
+ return - 1 , nil , nil , fmt .Errorf ("cannot read from udp connection: %w" , err )
539
+ }
540
+ if udpAdrr , ok := srcAddr .(* net.UDPAddr ); ok {
541
+ return n , cm , udpAdrr , nil
449
542
}
450
- return n , s , err
543
+ return - 1 , nil , nil , fmt . Errorf ( "cannot read from udp connection: invalid srcAddr type %T" , srcAddr )
451
544
}
452
545
453
546
// SetMulticastLoopback sets whether transmitted multicast packets
0 commit comments