diff --git a/dtls/server/session.go b/dtls/server/session.go index 90061a51..5f0d9154 100644 --- a/dtls/server/session.go +++ b/dtls/server/session.go @@ -146,7 +146,7 @@ func (s *Session) Run(cc *client.Conn) (err error) { return fmt.Errorf("cannot read from connection: %w", err) } readBuf = readBuf[:readLen] - err = cc.Process(readBuf) + err = cc.Process(nil, readBuf) if err != nil { return err } diff --git a/message/pool/message.go b/message/pool/message.go index 22cf8865..04fcc58c 100644 --- a/message/pool/message.go +++ b/message/pool/message.go @@ -10,6 +10,7 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/net" "go.uber.org/atomic" ) @@ -26,6 +27,7 @@ type Message struct { // Context context of request. ctx context.Context msg message.Message + controlMessage *net.ControlMessage // control message for UDP hijacked atomic.Bool isModified bool valueBuffer []byte @@ -73,6 +75,22 @@ func (r *Message) SetMessage(message message.Message) { r.isModified = true } +func (r *Message) SetControlMessage(cm *net.ControlMessage) { + r.controlMessage = cm +} + +func (r *Message) ControlMessage() *net.ControlMessage { + return r.controlMessage +} + +// UpsertControlMessage set value only when origin value is not set. +func (r *Message) UpsertControlMessage(cm *net.ControlMessage) { + if r.controlMessage != nil { + return + } + r.SetControlMessage(cm) +} + // SetMessageID only 0 to 2^16-1 are valid. func (r *Message) SetMessageID(mid int32) { r.msg.MessageID = mid @@ -120,6 +138,7 @@ func (r *Message) Reset() { r.valueBuffer = r.origValueBuffer r.body = nil r.isModified = false + r.controlMessage = nil if cap(r.bufferMarshal) > 1024 { r.bufferMarshal = make([]byte, 256) } @@ -568,6 +587,7 @@ func (r *Message) Clone(msg *Message) error { msg.ResetOptionsTo(r.Options()) msg.SetType(r.Type()) msg.SetMessageID(r.MessageID()) + msg.SetControlMessage(r.ControlMessage()) if r.Body() != nil { buf := bytes.NewBuffer(nil) diff --git a/net/connUDP.go b/net/connUDP.go index f749b22d..9a565682 100644 --- a/net/connUDP.go +++ b/net/connUDP.go @@ -24,10 +24,36 @@ type UDPConn struct { } type ControlMessage struct { + Dst net.IP // destination address, receiving only Src net.IP // source address, specifying only IfIndex int // interface index, must be 1 <= value when specifying } +func (c *ControlMessage) String() string { + if c == nil { + return "" + } + var sb strings.Builder + if c.Dst != nil { + sb.WriteString(fmt.Sprintf("Dst: %s, ", c.Dst)) + } + if c.Src != nil { + sb.WriteString(fmt.Sprintf("Src: %s, ", c.Src)) + } + if c.IfIndex >= 1 { + sb.WriteString(fmt.Sprintf("IfIndex: %d, ", c.IfIndex)) + } + return sb.String() +} + +// GetIfIndex returns the interface index of the network interface. 0 means no interface index specified. +func (c *ControlMessage) GetIfIndex() int { + if c == nil { + return 0 + } + return c.IfIndex +} + type packetConn interface { SetWriteDeadline(t time.Time) error WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) @@ -36,14 +62,18 @@ type packetConn interface { SetMulticastLoopback(on bool) error JoinGroup(ifi *net.Interface, group net.Addr) error LeaveGroup(ifi *net.Interface, group net.Addr) error + ReadFrom(b []byte) (n int, cm *ControlMessage, src net.Addr, err error) } type packetConnIPv4 struct { packetConnIPv4 *ipv4.PacketConn } -func newPacketConnIPv4(p *ipv4.PacketConn) *packetConnIPv4 { - return &packetConnIPv4{p} +func newPacketConnIPv4(p *ipv4.PacketConn) (*packetConnIPv4, error) { + if err := p.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface|ipv4.FlagSrc, true); err != nil { + return nil, err + } + return &packetConnIPv4{p}, nil } func (p *packetConnIPv4) SetMulticastInterface(ifi *net.Interface) error { @@ -65,6 +95,18 @@ func (p *packetConnIPv4) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n return p.packetConnIPv4.WriteTo(b, c, dst) } +func (p *packetConnIPv4) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) { + n, cm, src, err := p.packetConnIPv4.ReadFrom(b) + if err != nil { + return -1, nil, nil, err + } + return n, &ControlMessage{ + Dst: cm.Dst, + Src: cm.Src, + IfIndex: cm.IfIndex, + }, src, err +} + func (p *packetConnIPv4) SetMulticastHopLimit(hoplim int) error { return p.packetConnIPv4.SetMulticastTTL(hoplim) } @@ -85,8 +127,11 @@ type packetConnIPv6 struct { packetConnIPv6 *ipv6.PacketConn } -func newPacketConnIPv6(p *ipv6.PacketConn) *packetConnIPv6 { - return &packetConnIPv6{p} +func newPacketConnIPv6(p *ipv6.PacketConn) (*packetConnIPv6, error) { + if err := p.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface|ipv6.FlagSrc, true); err != nil { + return nil, err + } + return &packetConnIPv6{p}, nil } func (p *packetConnIPv6) SetMulticastInterface(ifi *net.Interface) error { @@ -97,6 +142,18 @@ func (p *packetConnIPv6) SetWriteDeadline(t time.Time) error { return p.packetConnIPv6.SetWriteDeadline(t) } +func (p *packetConnIPv6) ReadFrom(b []byte) (int, *ControlMessage, net.Addr, error) { + n, cm, src, err := p.packetConnIPv6.ReadFrom(b) + if err != nil { + return -1, nil, nil, err + } + return n, &ControlMessage{ + Dst: cm.Dst, + Src: cm.Src, + IfIndex: cm.IfIndex, + }, src, err +} + func (p *packetConnIPv6) WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) { var c *ipv6.ControlMessage if cm != nil { @@ -124,10 +181,6 @@ func (p *packetConnIPv6) LeaveGroup(ifi *net.Interface, group net.Addr) error { return p.packetConnIPv6.LeaveGroup(ifi, group) } -func (p *packetConnIPv6) SetControlMessage(on bool) error { - return p.packetConnIPv6.SetMulticastLoopback(on) -} - // IsIPv6 return's true if addr is IPV6. func IsIPv6(addr net.IP) bool { if ip := addr.To16(); ip != nil && ip.To4() == nil { @@ -174,10 +227,17 @@ func NewUDPConn(network string, c *net.UDPConn, opts ...UDPOption) *UDPConn { panic(fmt.Errorf("invalid address type(%T), UDP address expected", laddr)) } var pc packetConn + var err error if IsIPv6(addr.IP) { - pc = newPacketConnIPv6(ipv6.NewPacketConn(c)) + pc, err = newPacketConnIPv6(ipv6.NewPacketConn(c)) + if err != nil { + panic(fmt.Errorf("invalid UDP connection: %w", err)) + } } else { - pc = newPacketConnIPv4(ipv4.NewPacketConn(c)) + pc, err = newPacketConnIPv4(ipv4.NewPacketConn(c)) + if err != nil { + panic(fmt.Errorf("invalid UDP connection: %w", err)) + } } return &UDPConn{ @@ -214,11 +274,18 @@ func (c *UDPConn) Close() error { func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLimit int, raddr *net.UDPAddr, buffer []byte) error { var pktSrc net.IP var p packetConn + var err error if IsIPv6(raddr.IP) { - p = newPacketConnIPv6(ipv6.NewPacketConn(c.connection)) + p, err = newPacketConnIPv6(ipv6.NewPacketConn(c.connection)) + if err != nil { + return err + } pktSrc = net.IPv6zero } else { - p = newPacketConnIPv4(ipv4.NewPacketConn(c.connection)) + p, err = newPacketConnIPv4(ipv4.NewPacketConn(c.connection)) + if err != nil { + return err + } pktSrc = net.IPv4zero } if src != nil { @@ -229,15 +296,14 @@ func (c *UDPConn) writeToAddr(iface *net.Interface, src *net.IP, multicastHopLim return ErrConnectionIsClosed } if iface != nil { - if err := p.SetMulticastInterface(iface); err != nil { + if err = p.SetMulticastInterface(iface); err != nil { return err } } - if err := p.SetMulticastHopLimit(multicastHopLimit); err != nil { + if err = p.SetMulticastHopLimit(multicastHopLimit); err != nil { return err } - var err error if iface != nil || src != nil { _, err = p.WriteTo(buffer, &ControlMessage{ Src: pktSrc, @@ -409,7 +475,7 @@ func (c *UDPConn) writeMulticast(ctx context.Context, raddr *net.UDPAddr, buffer } // WriteWithContext writes data with context. -func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buffer []byte) error { +func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, cm *ControlMessage, buffer []byte) error { if raddr == nil { return fmt.Errorf("cannot write with context: invalid raddr") } @@ -422,7 +488,7 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff if c.closed.Load() { return ErrConnectionIsClosed } - n, err := WriteToUDP(c.connection, raddr, buffer) + n, err := c.packetConn.WriteTo(buffer, cm, raddr) if err != nil { return err } @@ -434,20 +500,23 @@ func (c *UDPConn) WriteWithContext(ctx context.Context, raddr *net.UDPAddr, buff } // ReadWithContext reads packet with context. -func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *net.UDPAddr, error) { +func (c *UDPConn) ReadWithContext(ctx context.Context, buffer []byte) (int, *ControlMessage, *net.UDPAddr, error) { select { case <-ctx.Done(): - return -1, nil, ctx.Err() + return -1, nil, nil, ctx.Err() default: } if c.closed.Load() { - return -1, nil, ErrConnectionIsClosed + return -1, nil, nil, ErrConnectionIsClosed } - n, s, err := c.connection.ReadFromUDP(buffer) + n, cm, srcAddr, err := c.packetConn.ReadFrom(buffer) if err != nil { - return -1, nil, fmt.Errorf("cannot read from udp connection: %w", err) + return -1, nil, nil, fmt.Errorf("cannot read from udp connection: %w", err) + } + if udpAdrr, ok := srcAddr.(*net.UDPAddr); ok { + return n, cm, udpAdrr, nil } - return n, s, err + return -1, nil, nil, fmt.Errorf("cannot read from udp connection: invalid srcAddr type %T", srcAddr) } // SetMulticastLoopback sets whether transmitted multicast packets diff --git a/net/connUDP_test.go b/net/connUDP_test.go index 5cfe3c8d..78b30d45 100644 --- a/net/connUDP_test.go +++ b/net/connUDP_test.go @@ -70,7 +70,7 @@ func TestUDPConnWriteWithContext(t *testing.T) { go func() { b := make([]byte, 1024) - _, _, errR := c2.ReadWithContext(ctx, b) + _, _, _, errR := c2.ReadWithContext(ctx, b) if errR != nil { return } @@ -78,7 +78,7 @@ func TestUDPConnWriteWithContext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, tt.args.buffer) + err = c1.WriteWithContext(tt.args.ctx, tt.args.udpCtx, nil, tt.args.buffer) c1.LocalAddr() c1.RemoteAddr() @@ -201,7 +201,7 @@ func TestUDPConnwriteMulticastWithContext(t *testing.T) { wg.Add(1) go func() { b := make([]byte, 1024) - n, _, errR := c2.ReadWithContext(ctx, b) + n, _, _, errR := c2.ReadWithContext(ctx, b) assert.NoError(t, errR) if n > 0 { b = b[:n] diff --git a/udp/client/conn.go b/udp/client/conn.go index 8e20d871..a164c790 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -716,6 +716,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con }() resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) + ifIndex := req.ControlMessage().GetIfIndex() w := responsewriter.New(resp, cc, req.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -730,6 +731,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con // nothing to send return } + upsertInterfaceToMessage(w.Message(), ifIndex) errW := cc.writeMessageAsync(w.Message()) if errW != nil { cc.closeConnection() @@ -741,6 +743,15 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess cc.sendPong(w, r) } +func upsertInterfaceToMessage(m *pool.Message, ifIndex int) { + if ifIndex >= 1 { + cm := coapNet.ControlMessage{ + IfIndex: ifIndex, + } + m.UpsertControlMessage(&cm) + } +} + func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { // ping request if r.Code() == codes.Empty && r.Type() == message.Confirmable && len(r.Token()) == 0 && len(r.Options()) == 0 && r.Body() == nil { @@ -752,6 +763,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { elem.ReleaseMessage(cc) resp := cc.AcquireMessage(cc.Context()) resp.SetToken(r.Token()) + upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex()) w := responsewriter.New(resp, cc, r.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -769,7 +781,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { return false } -func (cc *Conn) Process(datagram []byte) error { +func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error { if uint32(len(datagram)) > cc.session.MaxMessageSize() { return fmt.Errorf("max message size(%v) was exceeded %v", cc.session.MaxMessageSize(), len(datagram)) } @@ -779,6 +791,7 @@ func (cc *Conn) Process(datagram []byte) error { cc.ReleaseMessage(req) return err } + req.SetControlMessage(cm) req.SetSequence(cc.Sequence()) cc.checkMyMessageID(req) cc.inactivityMonitor.Notify() diff --git a/udp/client_test.go b/udp/client_test.go index a3f5b7dc..5781c7a3 100644 --- a/udp/client_test.go +++ b/udp/client_test.go @@ -758,7 +758,7 @@ func TestClientKeepAliveMonitor(t *testing.T) { go func() { defer serverWg.Done() for { - _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024)) + _, _, _, errR := ld.ReadWithContext(ctx, make([]byte, 1024)) if errR != nil { if errors.Is(errR, net.ErrClosed) { return diff --git a/udp/server/discover.go b/udp/server/discover.go index aa41af4c..86f81980 100644 --- a/udp/server/discover.go +++ b/udp/server/discover.go @@ -75,7 +75,7 @@ func (s *Server) DiscoveryRequest(req *pool.Message, address string, receiverFun return err } } else { - err = c.WriteWithContext(req.Context(), addr, data) + err = c.WriteWithContext(req.Context(), addr, nil, data) if err != nil { return err } diff --git a/udp/server/server.go b/udp/server/server.go index 8bceeab5..0f512699 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -145,7 +145,7 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { for { buf := m - n, raddr, err := l.ReadWithContext(s.ctx, buf) + n, cm, raddr, err := l.ReadWithContext(s.ctx, buf) if err != nil { wg.Wait() @@ -165,7 +165,7 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) continue } - err = cc.Process(buf) + err = cc.Process(cm, buf) if err != nil { s.closeConnection(cc) s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) diff --git a/udp/server/session.go b/udp/server/session.go index 870ff53c..103deb6d 100644 --- a/udp/server/session.go +++ b/udp/server/session.go @@ -109,7 +109,7 @@ func (s *Session) WriteMessage(req *pool.Message) error { if err != nil { return fmt.Errorf("cannot marshal: %w", err) } - return s.connection.WriteWithContext(req.Context(), s.raddr, data) + return s.connection.WriteWithContext(req.Context(), s.raddr, req.ControlMessage(), data) } // WriteMulticastMessage sends multicast to the remote multicast address. @@ -135,12 +135,12 @@ func (s *Session) Run(cc *client.Conn) (err error) { m := make([]byte, s.mtu) for { buf := m - n, _, err := s.connection.ReadWithContext(s.Context(), buf) + n, cm, _, err := s.connection.ReadWithContext(s.Context(), buf) if err != nil { return err } buf = buf[:n] - err = cc.Process(buf) + err = cc.Process(cm, buf) if err != nil { return err }