Skip to content

Commit 3ac36df

Browse files
committed
fixup! Update UDP: Ensure propagation of control message to pool.Message
1 parent 848a6cf commit 3ac36df

File tree

6 files changed

+226
-48
lines changed

6 files changed

+226
-48
lines changed

net/connUDP.go

Lines changed: 189 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -523,21 +523,106 @@ func (c *UDPConn) writeTo(raddr *net.UDPAddr, cm *ControlMessage, buffer []byte)
523523
return c.packetConn.WriteTo(buffer, cm, raddr)
524524
}
525525

526+
type UDPWriteCfg struct {
527+
Ctx context.Context
528+
RemoteAddr *net.UDPAddr
529+
ControlMessage *ControlMessage
530+
}
531+
532+
func (c *UDPWriteCfg) ApplyWrite(cfg *UDPWriteCfg) {
533+
if c.Ctx != nil {
534+
cfg.Ctx = c.Ctx
535+
}
536+
if c.RemoteAddr != nil {
537+
cfg.RemoteAddr = c.RemoteAddr
538+
}
539+
if c.ControlMessage != nil {
540+
cfg.ControlMessage = c.ControlMessage
541+
}
542+
}
543+
544+
type UDPWriteOption interface {
545+
ApplyWrite(cfg *UDPWriteCfg)
546+
}
547+
548+
type (
549+
UDPWriteApplyFunc func(cfg *UDPWriteCfg)
550+
UDPReadApplyFunc func(cfg *UDPReadCfg)
551+
)
552+
553+
type ReadWriteOptionHandler[F UDPWriteApplyFunc | UDPReadApplyFunc] struct {
554+
Func F
555+
}
556+
557+
func (o ReadWriteOptionHandler[F]) ApplyWrite(cfg *UDPWriteCfg) {
558+
switch f := any(o.Func).(type) {
559+
case UDPWriteApplyFunc:
560+
f(cfg)
561+
default:
562+
panic(fmt.Errorf("invalid option handler %T for UDP Write", o.Func))
563+
}
564+
}
565+
566+
func (o ReadWriteOptionHandler[F]) ApplyRead(cfg *UDPReadCfg) {
567+
switch f := any(o.Func).(type) {
568+
case UDPReadApplyFunc:
569+
f(cfg)
570+
default:
571+
panic(fmt.Errorf("invalid option handler %T for UDP Read", o.Func))
572+
}
573+
}
574+
575+
func writeOptionFunc(f UDPWriteApplyFunc) ReadWriteOptionHandler[UDPWriteApplyFunc] {
576+
return ReadWriteOptionHandler[UDPWriteApplyFunc]{
577+
Func: f,
578+
}
579+
}
580+
581+
type ContextOption struct {
582+
Ctx context.Context
583+
}
584+
585+
func (o ContextOption) ApplyWrite(cfg *UDPWriteCfg) {
586+
cfg.Ctx = o.Ctx
587+
}
588+
589+
// WithContext sets the context of operation.
590+
func WithContext(ctx context.Context) ContextOption {
591+
return ContextOption{Ctx: ctx}
592+
}
593+
594+
func (o ContextOption) ApplyRead(cfg *UDPReadCfg) {
595+
cfg.Ctx = o.Ctx
596+
}
597+
598+
// WithRemoteAddr sets the remote address to packet.
599+
func WithRemoteAddr(raddr *net.UDPAddr) UDPWriteOption {
600+
return writeOptionFunc(func(cfg *UDPWriteCfg) {
601+
cfg.RemoteAddr = raddr
602+
})
603+
}
604+
605+
// WithControlMessage sets the control message to packet.
606+
func WithControlMessage(cm *ControlMessage) UDPWriteOption {
607+
return writeOptionFunc(func(cfg *UDPWriteCfg) {
608+
cfg.ControlMessage = cm
609+
})
610+
}
611+
526612
// WriteWithContext writes data with context.
527-
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) error {
528-
if raddr == nil {
613+
func (c *UDPConn) writeWithCfg(buffer []byte, cfg UDPWriteCfg) error {
614+
if cfg.RemoteAddr == nil {
529615
return fmt.Errorf("cannot write with context: invalid raddr")
530616
}
531-
532617
select {
533-
case <-ctx.Done():
534-
return ctx.Err()
618+
case <-cfg.Ctx.Done():
619+
return cfg.Ctx.Err()
535620
default:
536621
}
537622
if c.closed.Load() {
538623
return ErrConnectionIsClosed
539624
}
540-
n, err := c.writeTo(raddr, cm, buffer)
625+
n, err := c.writeTo(cfg.RemoteAddr, cfg.ControlMessage, buffer)
541626
if err != nil {
542627
return err
543628
}
@@ -548,24 +633,114 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *
548633
return nil
549634
}
550635

551-
// ReadWithContext reads packet with context.
552-
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *ControlMessage, *net.UDPAddr, error) {
636+
// WriteWithOptions writes data with options. Via opts you can specify the remote address and control message.
637+
func (c *UDPConn) WriteWithOptions(buffer []byte, opts ...UDPWriteOption) error {
638+
cfg := UDPWriteCfg{
639+
Ctx: context.Background(),
640+
}
641+
addr := c.RemoteAddr()
642+
if addr != nil {
643+
if remoteAddr, ok := addr.(*net.UDPAddr); ok {
644+
cfg.RemoteAddr = remoteAddr
645+
}
646+
}
647+
for _, o := range opts {
648+
o.ApplyWrite(&cfg)
649+
}
650+
return c.writeWithCfg(buffer, cfg)
651+
}
652+
653+
// WriteWithContext writes data with context.
654+
func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error {
655+
return c.WriteWithOptions(buffer, WithContext(ctx), WithRemoteAddr(raddr))
656+
}
657+
658+
type UDPReadCfg struct {
659+
Ctx context.Context
660+
RemoteAddr **net.UDPAddr
661+
ControlMessage **ControlMessage
662+
}
663+
664+
func (c *UDPReadCfg) ApplyRead(cfg *UDPReadCfg) {
665+
if c.Ctx != nil {
666+
cfg.Ctx = c.Ctx
667+
}
668+
if c.RemoteAddr != nil {
669+
cfg.RemoteAddr = c.RemoteAddr
670+
}
671+
if c.ControlMessage != nil {
672+
cfg.ControlMessage = c.ControlMessage
673+
}
674+
}
675+
676+
type UDPReadOption interface {
677+
ApplyRead(cfg *UDPReadCfg)
678+
}
679+
680+
func readOptionFunc(f UDPReadApplyFunc) UDPReadOption {
681+
return ReadWriteOptionHandler[UDPReadApplyFunc]{
682+
Func: f,
683+
}
684+
}
685+
686+
// WithGetRemoteAddr fills the remote address when reading succeeds.
687+
func WithGetRemoteAddr(raddr **net.UDPAddr) UDPReadOption {
688+
return readOptionFunc(func(cfg *UDPReadCfg) {
689+
cfg.RemoteAddr = raddr
690+
})
691+
}
692+
693+
// WithGetControlMessage fills the control message when reading succeeds.
694+
func WithGetControlMessage(cm **ControlMessage) UDPReadOption {
695+
return readOptionFunc(func(cfg *UDPReadCfg) {
696+
cfg.ControlMessage = cm
697+
})
698+
}
699+
700+
func (c *UDPConn) readWithCfg(buffer []byte, cfg UDPReadCfg) (int, error) {
553701
select {
554-
case <-ctx.Done():
555-
return -1, nil, nil, ctx.Err()
702+
case <-cfg.Ctx.Done():
703+
return -1, cfg.Ctx.Err()
556704
default:
557705
}
558706
if c.closed.Load() {
559-
return -1, nil, nil, ErrConnectionIsClosed
707+
return -1, ErrConnectionIsClosed
560708
}
561709
n, cm, srcAddr, err := c.packetConn.ReadFrom(buffer)
562710
if err != nil {
563-
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: %w", err)
711+
return -1, fmt.Errorf("cannot read from udp connection: %w", err)
564712
}
565713
if udpAdrr, ok := srcAddr.(*net.UDPAddr); ok {
566-
return n, cm, udpAdrr, nil
714+
if cfg.RemoteAddr != nil {
715+
*cfg.RemoteAddr = udpAdrr
716+
}
717+
if cfg.ControlMessage != nil {
718+
*cfg.ControlMessage = cm
719+
}
720+
return n, nil
721+
}
722+
return -1, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr)
723+
}
724+
725+
// ReadWithOptions reads packet with options. Via opts you can get also the remote address and control message.
726+
func (c *UDPConn) ReadWithOptions(buffer []byte, opts ...UDPReadOption) (int, error) {
727+
cfg := UDPReadCfg{
728+
Ctx: context.Background(),
729+
}
730+
for _, o := range opts {
731+
o.ApplyRead(&cfg)
732+
}
733+
return c.readWithCfg(buffer, cfg)
734+
}
735+
736+
// ReadWithContext reads packet with context.
737+
func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) {
738+
var remoteAddr *net.UDPAddr
739+
n, err := c.ReadWithOptions(buffer, WithContext(ctx), WithGetRemoteAddr(&remoteAddr))
740+
if err != nil {
741+
return -1, nil, err
567742
}
568-
return -1, nil, nil, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr)
743+
return n, remoteAddr, err
569744
}
570745

571746
// SetMulticastLoopback sets whether transmitted multicast packets

net/connUDP_internal_test.go

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ func TestUDPConnWriteWithContext(t *testing.T) {
2727
ctxCancel()
2828

2929
type args struct {
30-
ctx context.Context
31-
udpCtx *net.UDPAddr
32-
buffer []byte
30+
ctx context.Context
31+
udpAddr *net.UDPAddr
32+
buffer []byte
3333
}
3434
tests := []struct {
3535
name string
@@ -39,9 +39,9 @@ func TestUDPConnWriteWithContext(t *testing.T) {
3939
{
4040
name: "valid",
4141
args: args{
42-
ctx: context.Background(),
43-
udpCtx: b,
44-
buffer: []byte("hello world"),
42+
ctx: context.Background(),
43+
udpAddr: b,
44+
buffer: []byte("hello world"),
4545
},
4646
},
4747
{
@@ -76,15 +76,15 @@ func TestUDPConnWriteWithContext(t *testing.T) {
7676

7777
go func() {
7878
b := make([]byte, 1024)
79-
_, _, _, errR := c2.ReadWithContext(ctx, b)
79+
_, _, errR := c2.ReadWithContext(ctx, b)
8080
if errR != nil {
8181
return
8282
}
8383
}()
8484

8585
for _, tt := range tests {
8686
t.Run(tt.name, func(t *testing.T) {
87-
err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, nil, tt.args.buffer)
87+
err = c1.WriteWithContext(tt.args.ctx, tt.args.udpAddr, tt.args.buffer)
8888

8989
c1.LocalAddr()
9090
c1.RemoteAddr()
@@ -119,10 +119,10 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) {
119119
require.NotEmpty(t, iface)
120120

121121
type args struct {
122-
ctx context.Context
123-
udpCtx *net.UDPAddr
124-
buffer []byte
125-
opts []MulticastOption
122+
ctx context.Context
123+
udpAddr *net.UDPAddr
124+
buffer []byte
125+
opts []MulticastOption
126126
}
127127
tests := []struct {
128128
name string
@@ -132,36 +132,36 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) {
132132
{
133133
name: "valid all interfaces",
134134
args: args{
135-
ctx: context.Background(),
136-
udpCtx: b,
137-
buffer: payload,
138-
opts: []MulticastOption{WithAllMulticastInterface()},
135+
ctx: context.Background(),
136+
udpAddr: b,
137+
buffer: payload,
138+
opts: []MulticastOption{WithAllMulticastInterface()},
139139
},
140140
},
141141
{
142142
name: "valid any interface",
143143
args: args{
144-
ctx: context.Background(),
145-
udpCtx: b,
146-
buffer: payload,
147-
opts: []MulticastOption{WithAnyMulticastInterface()},
144+
ctx: context.Background(),
145+
udpAddr: b,
146+
buffer: payload,
147+
opts: []MulticastOption{WithAnyMulticastInterface()},
148148
},
149149
},
150150
{
151151
name: "valid first interface",
152152
args: args{
153-
ctx: context.Background(),
154-
udpCtx: b,
155-
buffer: payload,
156-
opts: []MulticastOption{WithMulticastInterface(iface)},
153+
ctx: context.Background(),
154+
udpAddr: b,
155+
buffer: payload,
156+
opts: []MulticastOption{WithMulticastInterface(iface)},
157157
},
158158
},
159159
{
160160
name: "cancelled",
161161
args: args{
162-
ctx: ctxCanceled,
163-
udpCtx: b,
164-
buffer: payload,
162+
ctx: ctxCanceled,
163+
udpAddr: b,
164+
buffer: payload,
165165
},
166166
wantErr: true,
167167
},
@@ -207,7 +207,7 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) {
207207
wg.Add(1)
208208
go func() {
209209
b := make([]byte, 1024)
210-
n, _, _, errR := c2.ReadWithContext(ctx, b)
210+
n, _, errR := c2.ReadWithContext(ctx, b)
211211
assert.NoError(t, errR)
212212
if n > 0 {
213213
b = b[:n]
@@ -219,7 +219,7 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) {
219219

220220
for _, tt := range tests {
221221
t.Run(tt.name, func(t *testing.T) {
222-
err = c1.WriteMulticast(tt.args.ctx, tt.args.udpCtx, tt.args.buffer, tt.args.opts...)
222+
err = c1.WriteMulticast(tt.args.ctx, tt.args.udpAddr, tt.args.buffer, tt.args.opts...)
223223
c1.LocalAddr()
224224
c1.RemoteAddr()
225225

udp/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -758,7 +758,7 @@ func TestClientKeepAliveMonitor(t *testing.T) {
758758
go func() {
759759
defer serverWg.Done()
760760
for {
761-
_, _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024))
761+
_, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024))
762762
if errR != nil {
763763
if errors.Is(errR, net.ErrClosed) {
764764
return

udp/server/discover.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (s *Server) DiscoveryRequest(req *pool.Message, address string, receiverFun
7575
return err
7676
}
7777
} else {
78-
err = c.WriteWithContext(req.Context(), addr, nil, data)
78+
err = c.WriteWithContext(req.Context(), addr, data)
7979
if err != nil {
8080
return err
8181
}

udp/server/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ func (s *Server) Serve(l *coapNet.UDPConn) error {
145145

146146
for {
147147
buf := m
148-
n, cm, raddr, err := l.ReadWithContext(s.ctx, buf)
148+
var raddr *net.UDPAddr
149+
var cm *coapNet.ControlMessage
150+
n, err := l.ReadWithOptions(buf, coapNet.WithContext(s.ctx), coapNet.WithGetControlMessage(&cm), coapNet.WithGetRemoteAddr(&raddr))
149151
if err != nil {
150152
wg.Wait()
151153

0 commit comments

Comments
 (0)