From 859785d1cead8ac153c9ae3dcfa9947e57dc83a4 Mon Sep 17 00:00:00 2001 From: database64128 Date: Fri, 18 Oct 2024 16:11:14 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=9E=20conn,=20service:=20improve=20mms?= =?UTF-8?q?g=20api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Return early to indicate error, like the underlying syscall does. It no longer drops send errors on the floor, and allows more accurate stats. --- conn/conn_mmsg.go | 9 +- conn/mmsg.go | 119 ++++++++++-------- ...sysnum_mmsg_generic.go => mmsg_generic.go} | 0 ...sysnum_mmsg_linux32.go => mmsg_linux32.go} | 0 conn/syscall_mmsg.go | 26 ---- conn/ztypes_mmsg.go | 10 -- service/udp_nat.go | 2 - service/udp_nat_mmsg.go | 73 ++++++----- service/udp_session.go | 2 - service/udp_session_mmsg.go | 81 ++++++------ service/udp_transparent_linux.go | 86 +++++++------ 11 files changed, 208 insertions(+), 200 deletions(-) rename conn/{zsysnum_mmsg_generic.go => mmsg_generic.go} (100%) rename conn/{zsysnum_mmsg_linux32.go => mmsg_linux32.go} (100%) delete mode 100644 conn/syscall_mmsg.go delete mode 100644 conn/ztypes_mmsg.go diff --git a/conn/conn_mmsg.go b/conn/conn_mmsg.go index 52c3bdf..19bf8dc 100644 --- a/conn/conn_mmsg.go +++ b/conn/conn_mmsg.go @@ -7,15 +7,16 @@ import ( "net" ) -// ListenUDPRawConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [rawUDPConn] for batch I/O. -func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (c rawUDPConn, info SocketInfo, err error) { +// ListenUDPMmsgConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [MmsgConn] for +// reading and writing multiple messages using the recvmmsg(2) and sendmmsg(2) system calls. +func (lc *ListenConfig) ListenUDPMmsgConn(ctx context.Context, network, address string) (c MmsgConn, info SocketInfo, err error) { info.MaxUDPGSOSegments = 1 nlc := lc.tlc.ListenConfig nlc.Control = lc.fns.controlFunc(&info) pc, err := nlc.ListenPacket(ctx, network, address) if err != nil { - return rawUDPConn{}, info, err + return MmsgConn{}, info, err } - c, err = NewRawUDPConn(pc.(*net.UDPConn)) + c, err = NewMmsgConn(pc.(*net.UDPConn)) return c, info, err } diff --git a/conn/mmsg.go b/conn/mmsg.go index 678732b..dcbb78a 100644 --- a/conn/mmsg.go +++ b/conn/mmsg.go @@ -6,33 +6,37 @@ import ( "net" "os" "syscall" + "unsafe" + + "golang.org/x/sys/unix" ) -type rawUDPConn struct { +// MmsgConn wraps a [*net.UDPConn] and provides methods for reading and writing +// multiple messages using the recvmmsg(2) and sendmmsg(2) system calls. +type MmsgConn struct { *net.UDPConn rawConn syscall.RawConn } -// NewRawUDPConn wraps a [net.UDPConn] in a [rawUDPConn] for batch I/O. -func NewRawUDPConn(udpConn *net.UDPConn) (rawUDPConn, error) { +// NewMmsgConn returns a new [MmsgConn] for udpConn. +func NewMmsgConn(udpConn *net.UDPConn) (MmsgConn, error) { rawConn, err := udpConn.SyscallConn() if err != nil { - return rawUDPConn{}, err + return MmsgConn{}, err } - return rawUDPConn{ + return MmsgConn{ UDPConn: udpConn, rawConn: rawConn, }, nil } -// MmsgRConn wraps a [net.UDPConn] and provides the [ReadMsgs] method -// for reading multiple messages in a single recvmmsg(2) system call. +// MmsgRConn provides read access to the [MmsgConn]. // -// [MmsgRConn] is not safe for concurrent use. -// Use the [RConn] method to create a new [MmsgRConn] instance for each goroutine. +// MmsgRConn is not safe for concurrent use. +// Always create a new MmsgRConn for each goroutine. type MmsgRConn struct { - rawUDPConn + MmsgConn rawReadFunc func(fd uintptr) (done bool) readMsgvec []Mmsghdr readFlags int @@ -40,66 +44,66 @@ type MmsgRConn struct { readErr error } -// MmsgWConn wraps a [net.UDPConn] and provides the [WriteMsgs] method -// for writing multiple messages in a single sendmmsg(2) system call. +// MmsgWConn provides write access to the [MmsgConn]. // -// [MmsgWConn] is not safe for concurrent use. -// Use the [WConn] method to create a new [MmsgWConn] instance for each goroutine. +// MmsgWConn is not safe for concurrent use. +// Always create a new MmsgWConn for each goroutine. type MmsgWConn struct { - rawUDPConn + MmsgConn rawWriteFunc func(fd uintptr) (done bool) writeMsgvec []Mmsghdr writeFlags int - writeErrno syscall.Errno + writeN int + writeErr error } -// RConn returns a new [MmsgRConn] instance for batch reading. -func (c rawUDPConn) RConn() *MmsgRConn { - mmsgRConn := MmsgRConn{ - rawUDPConn: c, +// NewRConn returns the connection wrapped in a new [*MmsgRConn] for batch reading. +func (c MmsgConn) NewRConn() *MmsgRConn { + rc := MmsgRConn{ + MmsgConn: c, } - mmsgRConn.rawReadFunc = func(fd uintptr) (done bool) { + rc.rawReadFunc = func(fd uintptr) (done bool) { var errno syscall.Errno - mmsgRConn.readN, errno = recvmmsg(int(fd), mmsgRConn.readMsgvec, mmsgRConn.readFlags) + rc.readN, errno = recvmmsg(int(fd), rc.readMsgvec, rc.readFlags) switch errno { case 0: + rc.readErr = nil case syscall.EAGAIN: return false default: - mmsgRConn.readErr = os.NewSyscallError("recvmmsg", errno) + rc.readErr = os.NewSyscallError("recvmmsg", errno) } return true } - return &mmsgRConn + return &rc } -// WConn returns a new [MmsgWConn] instance for batch writing. -func (c rawUDPConn) WConn() *MmsgWConn { - mmsgWConn := MmsgWConn{ - rawUDPConn: c, +// NewWConn returns the connection wrapped in a new [*MmsgWConn] for batch writing. +func (c MmsgConn) NewWConn() *MmsgWConn { + wc := MmsgWConn{ + MmsgConn: c, } - mmsgWConn.rawWriteFunc = func(fd uintptr) (done bool) { + wc.rawWriteFunc = func(fd uintptr) (done bool) { + wc.writeN = 0 for { - n, errno := sendmmsg(int(fd), mmsgWConn.writeMsgvec, mmsgWConn.writeFlags) + n, errno := sendmmsg(int(fd), wc.writeMsgvec, wc.writeFlags) switch errno { case 0: case syscall.EAGAIN: return false default: - mmsgWConn.writeErrno = errno - mmsgWConn.writeMsgvec = mmsgWConn.writeMsgvec[1:] - if len(mmsgWConn.writeMsgvec) == 0 { - return true - } - continue + wc.writeErr = os.NewSyscallError("sendmmsg", errno) + return true } - mmsgWConn.writeMsgvec = mmsgWConn.writeMsgvec[n:] + wc.writeMsgvec = wc.writeMsgvec[n:] + wc.writeN += n - if len(mmsgWConn.writeMsgvec) == 0 { + if len(wc.writeMsgvec) == 0 { + wc.writeErr = nil return true } @@ -120,32 +124,49 @@ func (c rawUDPConn) WConn() *MmsgWConn { } } - return &mmsgWConn + return &wc } -// ReadMsgs reads as many messages as possible into the given msgvec +// ReadMsgs reads as many messages as possible into msgvec // and returns the number of messages read or an error. func (c *MmsgRConn) ReadMsgs(msgvec []Mmsghdr, flags int) (int, error) { c.readMsgvec = msgvec c.readFlags = flags - c.readN = 0 - c.readErr = nil if err := c.rawConn.Read(c.rawReadFunc); err != nil { return 0, err } return c.readN, c.readErr } -// WriteMsgs writes all messages in the given msgvec and returns the last encountered error. -func (c *MmsgWConn) WriteMsgs(msgvec []Mmsghdr, flags int) error { +// WriteMsgs writes the messages in msgvec to the connection. +// It returns the number of messages written as n, and if n < len(msgvec), +// the error from writing the n-th message. +func (c *MmsgWConn) WriteMsgs(msgvec []Mmsghdr, flags int) (int, error) { c.writeMsgvec = msgvec c.writeFlags = flags - c.writeErrno = 0 if err := c.rawConn.Write(c.rawWriteFunc); err != nil { - return err + return 0, err } - if c.writeErrno != 0 { - return os.NewSyscallError("sendmmsg", c.writeErrno) + return c.writeN, c.writeErr +} + +type Mmsghdr struct { + Msghdr unix.Msghdr + Msglen uint32 +} + +func mmsgSyscall(trap uintptr, fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { + r0, _, e1 := unix.Syscall6(trap, uintptr(fd), uintptr(unsafe.Pointer(unsafe.SliceData(msgvec))), uintptr(len(msgvec)), uintptr(flags), 0, 0) + if e1 != 0 { + return 0, e1 } - return nil + return int(r0), 0 +} + +func recvmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { + return mmsgSyscall(SYS_RECVMMSG, fd, msgvec, flags) +} + +func sendmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { + return mmsgSyscall(unix.SYS_SENDMMSG, fd, msgvec, flags) } diff --git a/conn/zsysnum_mmsg_generic.go b/conn/mmsg_generic.go similarity index 100% rename from conn/zsysnum_mmsg_generic.go rename to conn/mmsg_generic.go diff --git a/conn/zsysnum_mmsg_linux32.go b/conn/mmsg_linux32.go similarity index 100% rename from conn/zsysnum_mmsg_linux32.go rename to conn/mmsg_linux32.go diff --git a/conn/syscall_mmsg.go b/conn/syscall_mmsg.go deleted file mode 100644 index 6ffc886..0000000 --- a/conn/syscall_mmsg.go +++ /dev/null @@ -1,26 +0,0 @@ -//go:build linux || netbsd - -package conn - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -func mmsgSyscall(trap uintptr, fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { - r0, _, e1 := unix.Syscall6(trap, uintptr(fd), uintptr(unsafe.Pointer(unsafe.SliceData(msgvec))), uintptr(len(msgvec)), uintptr(flags), 0, 0) - if e1 != 0 { - return 0, e1 - } - return int(r0), 0 -} - -func recvmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { - return mmsgSyscall(SYS_RECVMMSG, fd, msgvec, flags) -} - -func sendmmsg(fd int, msgvec []Mmsghdr, flags int) (int, syscall.Errno) { - return mmsgSyscall(unix.SYS_SENDMMSG, fd, msgvec, flags) -} diff --git a/conn/ztypes_mmsg.go b/conn/ztypes_mmsg.go deleted file mode 100644 index 8d2f1ab..0000000 --- a/conn/ztypes_mmsg.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build linux || netbsd - -package conn - -import "golang.org/x/sys/unix" - -type Mmsghdr struct { - Msghdr unix.Msghdr - Msglen uint32 -} diff --git a/service/udp_nat.go b/service/udp_nat.go index 20db25f..ae747dd 100644 --- a/service/udp_nat.go +++ b/service/udp_nat.go @@ -475,9 +475,7 @@ func (s *UDPNATRelay) relayServerConnToNatConnGeneric(ctx context.Context, uplin if err != nil { uplink.logger.Warn("Failed to set read deadline on natConn", zap.Stringer("clientAddress", uplink.clientAddrPort), - zap.Stringer("targetAddress", &queuedPacket.targetAddr), zap.String("client", uplink.clientName), - zap.Stringer("writeDestAddress", destAddrPort), zap.Duration("natTimeout", uplink.natTimeout), zap.Error(err), ) diff --git a/service/udp_nat_mmsg.go b/service/udp_nat_mmsg.go index e13258e..669880e 100644 --- a/service/udp_nat_mmsg.go +++ b/service/udp_nat_mmsg.go @@ -56,7 +56,7 @@ func (s *UDPNATRelay) start(ctx context.Context, index int, lnc *udpRelayServerC } func (s *UDPNATRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error { - serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPMmsgConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -71,7 +71,7 @@ func (s *UDPNATRelay) startMmsg(ctx context.Context, index int, lnc *udpRelaySer s.mwg.Add(1) go func() { - s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn()) + s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.NewRConn()) s.mwg.Done() }() @@ -275,7 +275,7 @@ func (s *UDPNATRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRe return } - natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPMmsgConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", clientAddrPort), @@ -335,7 +335,7 @@ func (s *UDPNATRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRe s.relayServerConnToNatConnSendmmsg(ctx, natUplinkMmsg{ clientName: clientInfo.Name, clientAddrPort: clientAddrPort, - natConn: natConn.WConn(), + natConn: natConn.NewWConn(), natConnSendCh: natConnSendCh, natConnPacker: clientSession.Packer, natTimeout: lnc.natTimeout, @@ -352,10 +352,10 @@ func (s *UDPNATRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRe clientAddrPort: clientAddrPort, clientPktinfop: clientPktinfop, clientPktinfo: &entry.clientPktinfo, - natConn: natConn.RConn(), + natConn: natConn.NewRConn(), natConnRecvBufSize: clientSession.MaxPacketSize, natConnUnpacker: clientSession.Unpacker, - serverConn: serverConn.WConn(), + serverConn: serverConn.NewWConn(), serverConnPacker: serverConnPacker, relayBatchSize: lnc.relayBatchSize, logger: lnc.logger, @@ -412,6 +412,7 @@ func (s *UDPNATRelay) relayServerConnToNatConnSendmmsg(ctx context.Context, upli ) qpvec := make([]*natQueuedPacket, uplink.relayBatchSize) + dapvec := make([]netip.AddrPort, uplink.relayBatchSize) namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize) iovec := make([]unix.Iovec, uplink.relayBatchSize) msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize) @@ -454,6 +455,7 @@ main: } qpvec[count] = queuedPacket + dapvec[count] = destAddrPort namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort) iovec[count].Base = &queuedPacket.buf[packetStart] iovec[count].SetLen(packetLength) @@ -475,31 +477,35 @@ main: } } - if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil { - uplink.logger.Warn("Failed to batch write packets to natConn", - zap.Stringer("clientAddress", uplink.clientAddrPort), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr), - zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), - zap.Error(err), - ) + for start := 0; start < count; { + n, err := uplink.natConn.WriteMsgs(msgvec[start:count], 0) + start += n + if err != nil { + uplink.logger.Warn("Failed to batch write packets to natConn", + zap.Stringer("clientAddress", uplink.clientAddrPort), + zap.Stringer("targetAddress", &qpvec[start].targetAddr), + zap.String("client", uplink.clientName), + zap.Stringer("writeDestAddress", &dapvec[start]), + zap.Uint("packetLength", uint(iovec[start].Len)), + zap.Error(err), + ) + start++ + } + + sendmmsgCount++ + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) } if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil { uplink.logger.Warn("Failed to set read deadline on natConn", zap.Stringer("clientAddress", uplink.clientAddrPort), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr), zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), zap.Duration("natTimeout", uplink.natTimeout), zap.Error(err), ) } - sendmmsgCount++ - packetsSent += uint64(count) - burstBatchSize = max(burstBatchSize, count) - qpvecn := qpvec[:count] for i := range qpvecn { @@ -659,18 +665,23 @@ func (s *UDPNATRelay) relayNatConnToServerConnSendmmsg(downlink natDownlinkMmsg) } } - err = downlink.serverConn.WriteMsgs(smsgvec[:ns], 0) - if err != nil { - downlink.logger.Warn("Failed to batch write packets to serverConn", - zap.Stringer("clientAddress", downlink.clientAddrPort), - zap.String("client", downlink.clientName), - zap.Error(err), - ) - } + for start := 0; start < ns; { + n, err := downlink.serverConn.WriteMsgs(smsgvec[start:ns], 0) + start += n + if err != nil { + downlink.logger.Warn("Failed to batch write packets to serverConn", + zap.Stringer("clientAddress", downlink.clientAddrPort), + zap.String("client", downlink.clientName), + zap.Uint("packetLength", uint(siovec[start].Len)), + zap.Error(err), + ) + start++ + } - sendmmsgCount++ - packetsSent += uint64(ns) - burstBatchSize = max(burstBatchSize, ns) + sendmmsgCount++ + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) + } } downlink.logger.Info("Finished relay serverConn <- natConn", diff --git a/service/udp_session.go b/service/udp_session.go index a168e8e..6a00a02 100644 --- a/service/udp_session.go +++ b/service/udp_session.go @@ -548,9 +548,7 @@ func (s *UDPSessionRelay) relayServerConnToNatConnGeneric(ctx context.Context, u zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), zap.String("username", uplink.username), zap.Uint64("clientSessionID", uplink.csid), - zap.Stringer("targetAddress", &queuedPacket.targetAddr), zap.String("client", uplink.clientName), - zap.Stringer("writeDestAddress", destAddrPort), zap.Duration("natTimeout", uplink.natTimeout), zap.Error(err), ) diff --git a/service/udp_session_mmsg.go b/service/udp_session_mmsg.go index 85b6eb3..709b4c6 100644 --- a/service/udp_session_mmsg.go +++ b/service/udp_session_mmsg.go @@ -58,7 +58,7 @@ func (s *UDPSessionRelay) start(ctx context.Context, index int, lnc *udpRelaySer } func (s *UDPSessionRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error { - serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPMmsgConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -73,7 +73,7 @@ func (s *UDPSessionRelay) startMmsg(ctx context.Context, index int, lnc *udpRela s.mwg.Add(1) go func() { - s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn()) + s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.NewRConn()) s.mwg.Done() }() @@ -316,7 +316,7 @@ func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *u return } - natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPMmsgConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), @@ -384,7 +384,7 @@ func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *u s.relayServerConnToNatConnSendmmsg(ctx, sessionUplinkMmsg{ csid: csid, clientName: clientInfo.Name, - natConn: natConn.WConn(), + natConn: natConn.NewWConn(), natConnSendCh: natConnSendCh, natConnPacker: clientSession.Packer, natTimeout: lnc.natTimeout, @@ -402,10 +402,10 @@ func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *u clientName: clientInfo.Name, clientAddrInfop: clientAddrInfop, clientAddrInfo: &entry.clientAddrInfo, - natConn: natConn.RConn(), + natConn: natConn.NewRConn(), natConnRecvBufSize: clientSession.MaxPacketSize, natConnUnpacker: clientSession.Unpacker, - serverConn: serverConn.WConn(), + serverConn: serverConn.NewWConn(), serverConnPacker: serverConnPacker, username: entry.username, relayBatchSize: lnc.relayBatchSize, @@ -467,6 +467,7 @@ func (s *UDPSessionRelay) relayServerConnToNatConnSendmmsg(ctx context.Context, ) qpvec := make([]*sessionQueuedPacket, uplink.relayBatchSize) + dapvec := make([]netip.AddrPort, uplink.relayBatchSize) namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize) iovec := make([]unix.Iovec, uplink.relayBatchSize) msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize) @@ -511,6 +512,7 @@ main: } qpvec[count] = queuedPacket + dapvec[count] = destAddrPort namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort) iovec[count].Base = &queuedPacket.buf[packetStart] iovec[count].SetLen(packetLength) @@ -532,16 +534,26 @@ main: } } - if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil { - uplink.logger.Warn("Failed to batch write packets to natConn", - zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), - zap.String("username", uplink.username), - zap.Uint64("clientSessionID", uplink.csid), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr), - zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), - zap.Error(err), - ) + for start := 0; start < count; { + n, err := uplink.natConn.WriteMsgs(msgvec[start:count], 0) + start += n + if err != nil { + uplink.logger.Warn("Failed to batch write packets to natConn", + zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), + zap.String("username", uplink.username), + zap.Uint64("clientSessionID", uplink.csid), + zap.Stringer("targetAddress", &qpvec[start].targetAddr), + zap.String("client", uplink.clientName), + zap.Stringer("writeDestAddress", &dapvec[start]), + zap.Uint("packetLength", uint(iovec[start].Len)), + zap.Error(err), + ) + start++ + } + + sendmmsgCount++ + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) } if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil { @@ -549,18 +561,12 @@ main: zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), zap.String("username", uplink.username), zap.Uint64("clientSessionID", uplink.csid), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddr), zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), zap.Duration("natTimeout", uplink.natTimeout), zap.Error(err), ) } - sendmmsgCount++ - packetsSent += uint64(count) - burstBatchSize = max(burstBatchSize, count) - qpvecn := qpvec[:count] for i := range qpvecn { @@ -735,20 +741,25 @@ func (s *UDPSessionRelay) relayNatConnToServerConnSendmmsg(downlink sessionDownl continue } - err = downlink.serverConn.WriteMsgs(smsgvec[:ns], 0) - if err != nil { - downlink.logger.Warn("Failed to batch write packets to serverConn", - zap.Stringer("clientAddress", clientAddrPort), - zap.String("username", downlink.username), - zap.Uint64("clientSessionID", downlink.csid), - zap.String("client", downlink.clientName), - zap.Error(err), - ) - } + for start := 0; start < ns; { + n, err := downlink.serverConn.WriteMsgs(smsgvec[start:ns], 0) + start += n + if err != nil { + downlink.logger.Warn("Failed to batch write packets to serverConn", + zap.Stringer("clientAddress", clientAddrPort), + zap.String("username", downlink.username), + zap.Uint64("clientSessionID", downlink.csid), + zap.String("client", downlink.clientName), + zap.Uint("packetLength", uint(siovec[start].Len)), + zap.Error(err), + ) + start++ + } - sendmmsgCount++ - packetsSent += uint64(ns) - burstBatchSize = max(burstBatchSize, ns) + sendmmsgCount++ + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) + } } downlink.logger.Info("Finished relay serverConn <- natConn", diff --git a/service/udp_transparent_linux.go b/service/udp_transparent_linux.go index ab9cf3b..4bbdcdc 100644 --- a/service/udp_transparent_linux.go +++ b/service/udp_transparent_linux.go @@ -128,7 +128,7 @@ func (s *UDPTransparentRelay) Start(ctx context.Context) error { index := i lnc := &s.listeners[index] - serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPMmsgConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -143,7 +143,7 @@ func (s *UDPTransparentRelay) Start(ctx context.Context) error { s.mwg.Add(1) go func() { - s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.RConn()) + s.recvFromServerConnRecvmmsg(ctx, lnc, serverConn.NewRConn()) s.mwg.Done() }() @@ -297,7 +297,7 @@ func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, ln return } - natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPMmsgConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", clientAddrPort), @@ -344,7 +344,7 @@ func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, ln s.relayServerConnToNatConnSendmmsg(ctx, transparentUplink{ clientName: clientInfo.Name, clientAddrPort: clientAddrPort, - natConn: natConn.WConn(), + natConn: natConn.NewWConn(), natConnSendCh: natConnSendCh, natConnPacker: clientSession.Packer, natTimeout: lnc.natTimeout, @@ -359,7 +359,7 @@ func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, ln s.relayNatConnToTransparentConnSendmmsg(ctx, transparentDownlink{ clientName: clientInfo.Name, clientAddrPort: clientAddrPort, - natConn: natConn.RConn(), + natConn: natConn.NewRConn(), natConnRecvBufSize: clientSession.MaxPacketSize, natConnUnpacker: clientSession.Unpacker, relayBatchSize: lnc.relayBatchSize, @@ -416,6 +416,7 @@ func (s *UDPTransparentRelay) relayServerConnToNatConnSendmmsg(ctx context.Conte ) qpvec := make([]*transparentQueuedPacket, uplink.relayBatchSize) + dapvec := make([]netip.AddrPort, uplink.relayBatchSize) namevec := make([]unix.RawSockaddrInet6, uplink.relayBatchSize) iovec := make([]unix.Iovec, uplink.relayBatchSize) msgvec := make([]conn.Mmsghdr, uplink.relayBatchSize) @@ -458,6 +459,7 @@ main: } qpvec[count] = queuedPacket + dapvec[count] = destAddrPort namevec[count] = conn.AddrPortToSockaddrInet6(destAddrPort) iovec[count].Base = &queuedPacket.buf[packetStart] iovec[count].SetLen(packetLength) @@ -479,31 +481,35 @@ main: } } - if err := uplink.natConn.WriteMsgs(msgvec[:count], 0); err != nil { - uplink.logger.Warn("Failed to batch write packets to natConn", - zap.Stringer("clientAddress", uplink.clientAddrPort), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort), - zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), - zap.Error(err), - ) + for start := 0; start < count; { + n, err := uplink.natConn.WriteMsgs(msgvec[start:count], 0) + start += n + if err != nil { + uplink.logger.Warn("Failed to batch write packets to natConn", + zap.Stringer("clientAddress", uplink.clientAddrPort), + zap.Stringer("targetAddress", &qpvec[start].targetAddrPort), + zap.String("client", uplink.clientName), + zap.Stringer("writeDestAddress", &dapvec[start]), + zap.Uint("packetLength", uint(iovec[start].Len)), + zap.Error(err), + ) + start++ + } + + sendmmsgCount++ + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) } if err := uplink.natConn.SetReadDeadline(time.Now().Add(uplink.natTimeout)); err != nil { uplink.logger.Warn("Failed to set read deadline on natConn", zap.Stringer("clientAddress", uplink.clientAddrPort), - zap.Stringer("lastTargetAddress", &qpvec[count-1].targetAddrPort), zap.String("client", uplink.clientName), - zap.Stringer("lastWriteDestAddress", destAddrPort), zap.Duration("natTimeout", uplink.natTimeout), zap.Error(err), ) } - sendmmsgCount++ - packetsSent += uint64(count) - burstBatchSize = max(burstBatchSize, count) - qpvecn := qpvec[:count] for i := range qpvecn { @@ -546,7 +552,7 @@ type transparentConn struct { } func (s *UDPTransparentRelay) newTransparentConn(ctx context.Context, address string, relayBatchSize int, name *byte, namelen uint32) (*transparentConn, error) { - c, _, err := s.transparentConnListenConfig.ListenUDPRawConn(ctx, "udp", address) + c, _, err := s.transparentConnListenConfig.ListenUDPMmsgConn(ctx, "udp", address) if err != nil { return nil, err } @@ -562,7 +568,7 @@ func (s *UDPTransparentRelay) newTransparentConn(ctx context.Context, address st } return &transparentConn{ - mwc: c.WConn(), + mwc: c.NewWConn(), iovec: iovec, msgvec: msgvec, }, nil @@ -574,15 +580,6 @@ func (tc *transparentConn) putMsg(base *byte, length int) { tc.n++ } -func (tc *transparentConn) writeMsgvec() (sendmmsgCount, packetsSent int, err error) { - if tc.n == 0 { - return - } - packetsSent = tc.n - tc.n = 0 - return 1, packetsSent, tc.mwc.WriteMsgs(tc.msgvec[:packetsSent], 0) -} - func (tc *transparentConn) close() error { return tc.mwc.Close() } @@ -710,19 +707,26 @@ func (s *UDPTransparentRelay) relayNatConnToTransparentConnSendmmsg(ctx context. } for payloadSourceAddrPort, tc := range tcMap { - sc, ps, err := tc.writeMsgvec() - if err != nil { - downlink.logger.Warn("Failed to batch write packets to transparentConn", - zap.Stringer("clientAddress", downlink.clientAddrPort), - zap.String("client", downlink.clientName), - zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), - zap.Error(err), - ) + for start := 0; start < tc.n; { + n, err := tc.mwc.WriteMsgs(tc.msgvec[start:tc.n], 0) + start += n + if err != nil { + downlink.logger.Warn("Failed to batch write packets to transparentConn", + zap.Stringer("clientAddress", downlink.clientAddrPort), + zap.String("client", downlink.clientName), + zap.Stringer("payloadSourceAddress", payloadSourceAddrPort), + zap.Uint("packetLength", uint(tc.iovec[start].Len)), + zap.Error(err), + ) + start++ + } + + sendmmsgCount += uint64(n) + packetsSent += uint64(n) + burstBatchSize = max(burstBatchSize, n) } - sendmmsgCount += uint64(sc) - packetsSent += uint64(ps) - burstBatchSize = max(burstBatchSize, ps) + tc.n = 0 } }