diff --git a/conn/conn.go b/conn/conn.go index 356eef8..0c6046d 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -8,18 +8,31 @@ import ( "github.com/database64128/tfo-go/v2" ) -type setFunc = func(fd int, network string) error +// SocketInfo contains information about a socket. +type SocketInfo struct { + // MaxUDPGSOSegments is the maximum number of UDP GSO segments supported by the socket. + // + // If UDP GSO is not enabled on the socket, or the system does not support UDP GSO, the value is 1. + // + // The value is 0 if the socket is not a UDP socket. + MaxUDPGSOSegments int + + // UDPGenericReceiveOffload indicates whether UDP GRO is enabled on the socket. + UDPGenericReceiveOffload bool +} + +type setFunc = func(fd int, network string, info *SocketInfo) error type setFuncSlice []setFunc -func (fns setFuncSlice) controlContextFunc() func(ctx context.Context, network, address string, c syscall.RawConn) error { +func (fns setFuncSlice) controlContextFunc(info *SocketInfo) func(ctx context.Context, network, address string, c syscall.RawConn) error { if len(fns) == 0 { return nil } return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) { if cerr := c.Control(func(fd uintptr) { for _, fn := range fns { - if err = fn(int(fd), network); err != nil { + if err = fn(int(fd), network, info); err != nil { return } } @@ -30,14 +43,14 @@ func (fns setFuncSlice) controlContextFunc() func(ctx context.Context, network, } } -func (fns setFuncSlice) controlFunc() func(network, address string, c syscall.RawConn) error { +func (fns setFuncSlice) controlFunc(info *SocketInfo) func(network, address string, c syscall.RawConn) error { if len(fns) == 0 { return nil } return func(network, address string, c syscall.RawConn) (err error) { if cerr := c.Control(func(fd uintptr) { for _, fn := range fns { - if err = fn(int(fd), network); err != nil { + if err = fn(int(fd), network, info); err != nil { return } } @@ -48,25 +61,32 @@ func (fns setFuncSlice) controlFunc() func(network, address string, c syscall.Ra } } -// ListenConfig is [tfo.ListenConfig] but provides a subjectively nicer API. -type ListenConfig tfo.ListenConfig +// ListenConfig wraps a [tfo.ListenConfig] and provides a subjectively nicer API. +type ListenConfig struct { + tlc tfo.ListenConfig + fns setFuncSlice +} // ListenTCP wraps [tfo.ListenConfig.Listen] and returns a [*net.TCPListener] directly. -func (lc *ListenConfig) ListenTCP(ctx context.Context, network, address string) (*net.TCPListener, error) { - l, err := (*tfo.ListenConfig)(lc).Listen(ctx, network, address) +func (lc *ListenConfig) ListenTCP(ctx context.Context, network, address string) (tln *net.TCPListener, info SocketInfo, err error) { + tlc := lc.tlc + tlc.Control = lc.fns.controlFunc(&info) + ln, err := tlc.Listen(ctx, network, address) if err != nil { - return nil, err + return nil, info, err } - return l.(*net.TCPListener), nil + return ln.(*net.TCPListener), info, nil } // ListenUDP wraps [net.ListenConfig.ListenPacket] and returns a [*net.UDPConn] directly. -func (lc *ListenConfig) ListenUDP(ctx context.Context, network, address string) (*net.UDPConn, error) { - pc, err := lc.ListenConfig.ListenPacket(ctx, network, address) +func (lc *ListenConfig) ListenUDP(ctx context.Context, network, address string) (uc *net.UDPConn, info SocketInfo, err error) { + nlc := lc.tlc.ListenConfig + nlc.Control = lc.fns.controlFunc(&info) + pc, err := nlc.ListenPacket(ctx, network, address) if err != nil { - return nil, err + return nil, info, err } - return pc.(*net.UDPConn), nil + return pc.(*net.UDPConn), info, nil } // ListenerSocketOptions contains listener-specific socket options. @@ -164,17 +184,17 @@ type ListenerSocketOptions struct { ReceiveOriginalDestAddr bool } -// ListenConfig returns a [ListenConfig] with a control function that sets the socket options. +// ListenConfig returns a [ListenConfig] that sets the socket options. func (lso ListenerSocketOptions) ListenConfig() ListenConfig { lc := ListenConfig{ - ListenConfig: net.ListenConfig{ - Control: lso.buildSetFns().controlFunc(), + tlc: tfo.ListenConfig{ + Backlog: lso.TCPFastOpenBacklog, + DisableTFO: !lso.TCPFastOpen, + Fallback: lso.TCPFastOpenFallback, }, - Backlog: lso.TCPFastOpenBacklog, - DisableTFO: !lso.TCPFastOpen, - Fallback: lso.TCPFastOpenFallback, + fns: lso.buildSetFns(), } - lc.SetMultipathTCP(lso.MultipathTCP) + lc.tlc.SetMultipathTCP(lso.MultipathTCP) return lc } @@ -277,7 +297,7 @@ type DialerSocketOptions struct { func (dso DialerSocketOptions) Dialer() Dialer { d := Dialer{ Dialer: net.Dialer{ - ControlContext: dso.buildSetFns().controlContextFunc(), + ControlContext: dso.buildSetFns().controlContextFunc(nil), }, DisableTFO: !dso.TCPFastOpen, Fallback: dso.TCPFastOpenFallback, diff --git a/conn/conn_bufsize_posix.go b/conn/conn_bufsize_posix.go index f52e6aa..190ab09 100644 --- a/conn/conn_bufsize_posix.go +++ b/conn/conn_bufsize_posix.go @@ -4,7 +4,7 @@ package conn func (fns setFuncSlice) appendSetSendBufferSize(size int) setFuncSlice { if size > 0 { - return append(fns, func(fd int, _ string) error { + return append(fns, func(fd int, _ string, _ *SocketInfo) error { return setSendBufferSize(fd, size) }) } @@ -13,7 +13,7 @@ func (fns setFuncSlice) appendSetSendBufferSize(size int) setFuncSlice { func (fns setFuncSlice) appendSetRecvBufferSize(size int) setFuncSlice { if size > 0 { - return append(fns, func(fd int, _ string) error { + return append(fns, func(fd int, _ string, _ *SocketInfo) error { return setRecvBufferSize(fd, size) }) } diff --git a/conn/conn_darwinlinuxwindows.go b/conn/conn_darwinlinuxwindows.go index c63c513..b2f3cfd 100644 --- a/conn/conn_darwinlinuxwindows.go +++ b/conn/conn_darwinlinuxwindows.go @@ -4,7 +4,9 @@ package conn func (fns setFuncSlice) appendSetRecvPktinfoFunc(recvPktinfo bool) setFuncSlice { if recvPktinfo { - return append(fns, setRecvPktinfo) + return append(fns, func(fd int, network string, _ *SocketInfo) error { + return setRecvPktinfo(fd, network) + }) } return fns } diff --git a/conn/conn_flags_tclass.go b/conn/conn_flags_tclass.go index 0f84548..b4134e2 100644 --- a/conn/conn_flags_tclass.go +++ b/conn/conn_flags_tclass.go @@ -33,7 +33,7 @@ func ParseFlagsForError(flags int) error { func (fns setFuncSlice) appendSetTrafficClassFunc(trafficClass int) setFuncSlice { if trafficClass != 0 { - return append(fns, func(fd int, network string) error { + return append(fns, func(fd int, network string, _ *SocketInfo) error { return setTrafficClass(fd, network, trafficClass) }) } diff --git a/conn/conn_freebsdlinux.go b/conn/conn_freebsdlinux.go index 356d5eb..172cd0b 100644 --- a/conn/conn_freebsdlinux.go +++ b/conn/conn_freebsdlinux.go @@ -4,7 +4,7 @@ package conn func (fns setFuncSlice) appendSetFwmarkFunc(fwmark int) setFuncSlice { if fwmark != 0 { - return append(fns, func(fd int, network string) error { + return append(fns, func(fd int, network string, _ *SocketInfo) error { return setFwmark(fd, fwmark) }) } diff --git a/conn/conn_linux.go b/conn/conn_linux.go index 8f75f2a..0dc2975 100644 --- a/conn/conn_linux.go +++ b/conn/conn_linux.go @@ -67,8 +67,10 @@ func setTCPUserTimeout(fd, msecs int) error { return nil } -func setUDPGenericReceiveOffload(fd int) { - _ = unix.SetsockoptInt(fd, unix.IPPROTO_UDP, unix.UDP_GRO, 1) +func setUDPGenericReceiveOffload(fd int, info *SocketInfo) { + if err := unix.SetsockoptInt(fd, unix.IPPROTO_UDP, unix.UDP_GRO, 1); err == nil { + info.UDPGenericReceiveOffload = true + } } func setTransparent(fd int, network string) error { @@ -143,7 +145,7 @@ func setRecvOrigDstAddr(fd int, network string) error { func (fns setFuncSlice) appendSetTCPDeferAcceptFunc(deferAcceptSecs int) setFuncSlice { if deferAcceptSecs > 0 { - return append(fns, func(fd int, network string) error { + return append(fns, func(fd int, network string, _ *SocketInfo) error { return setTCPDeferAccept(fd, deferAcceptSecs) }) } @@ -152,7 +154,7 @@ func (fns setFuncSlice) appendSetTCPDeferAcceptFunc(deferAcceptSecs int) setFunc func (fns setFuncSlice) appendSetTCPUserTimeoutFunc(userTimeoutMsecs int) setFuncSlice { if userTimeoutMsecs > 0 { - return append(fns, func(fd int, network string) error { + return append(fns, func(fd int, network string, _ *SocketInfo) error { return setTCPUserTimeout(fd, userTimeoutMsecs) }) } @@ -161,14 +163,18 @@ func (fns setFuncSlice) appendSetTCPUserTimeoutFunc(userTimeoutMsecs int) setFun func (fns setFuncSlice) appendSetTransparentFunc(transparent bool) setFuncSlice { if transparent { - return append(fns, setTransparent) + return append(fns, func(fd int, network string, _ *SocketInfo) error { + return setTransparent(fd, network) + }) } return fns } func (fns setFuncSlice) appendSetRecvOrigDstAddrFunc(recvOrigDstAddr bool) setFuncSlice { if recvOrigDstAddr { - return append(fns, setRecvOrigDstAddr) + return append(fns, func(fd int, network string, _ *SocketInfo) error { + return setRecvOrigDstAddr(fd, network) + }) } return fns } diff --git a/conn/conn_mmsg.go b/conn/conn_mmsg.go index 526e59a..11f6d34 100644 --- a/conn/conn_mmsg.go +++ b/conn/conn_mmsg.go @@ -8,10 +8,13 @@ import ( ) // ListenUDPRawConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [rawUDPConn] for batch I/O. -func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (rawUDPConn, error) { - pc, err := lc.ListenConfig.ListenPacket(ctx, network, address) +func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (c rawUDPConn, info SocketInfo, err error) { + nlc := lc.tlc.ListenConfig + nlc.Control = lc.fns.controlFunc(&info) + pc, err := nlc.ListenPacket(ctx, network, address) if err != nil { - return rawUDPConn{}, err + return rawUDPConn{}, info, err } - return NewRawUDPConn(pc.(*net.UDPConn)) + c, err = NewRawUDPConn(pc.(*net.UDPConn)) + return c, info, err } diff --git a/conn/conn_pmtud.go b/conn/conn_pmtud.go index eb1b926..866f37a 100644 --- a/conn/conn_pmtud.go +++ b/conn/conn_pmtud.go @@ -4,7 +4,9 @@ package conn func (fns setFuncSlice) appendSetPMTUDFunc(pmtud bool) setFuncSlice { if pmtud { - return append(fns, setPMTUD) + return append(fns, func(fd int, network string, _ *SocketInfo) error { + return setPMTUD(fd, network) + }) } return fns } diff --git a/conn/conn_reuseport.go b/conn/conn_reuseport.go index 5e9a7c0..d52113b 100644 --- a/conn/conn_reuseport.go +++ b/conn/conn_reuseport.go @@ -17,7 +17,7 @@ func setReusePort(fd int) error { func (fns setFuncSlice) appendSetReusePortFunc(reusePort bool) setFuncSlice { if reusePort { - return append(fns, func(fd int, network string) error { + return append(fns, func(fd int, network string, _ *SocketInfo) error { return setReusePort(fd) }) } diff --git a/conn/conn_udpgro.go b/conn/conn_udpgro.go index d185b3c..646c4b0 100644 --- a/conn/conn_udpgro.go +++ b/conn/conn_udpgro.go @@ -4,8 +4,8 @@ package conn func (fns setFuncSlice) appendSetUDPGenericReceiveOffloadFunc(gro bool) setFuncSlice { if gro { - return append(fns, func(fd int, _ string) error { - setUDPGenericReceiveOffload(fd) + return append(fns, func(fd int, _ string, info *SocketInfo) error { + setUDPGenericReceiveOffload(fd, info) return nil }) } diff --git a/conn/conn_windows.go b/conn/conn_windows.go index 923e8ea..3427456 100644 --- a/conn/conn_windows.go +++ b/conn/conn_windows.go @@ -53,9 +53,11 @@ func setPMTUD(fd int, network string) error { return nil } -func setUDPGenericReceiveOffload(fd int) { +func setUDPGenericReceiveOffload(fd int, info *SocketInfo) { // Both quinn and msquic set this to 65535. - _ = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_UDP, windows.UDP_RECV_MAX_COALESCED_SIZE, 65535) + if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_UDP, windows.UDP_RECV_MAX_COALESCED_SIZE, 65535); err == nil { + info.UDPGenericReceiveOffload = true + } } func setRecvPktinfo(fd int, network string) error { diff --git a/dns/dns.go b/dns/dns.go index e46e13e..20dee03 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -311,7 +311,7 @@ func (r *Resolver) sendQueriesUDP(ctx context.Context, nameString string, q4Pkt, } defer clientSession.Close() - udpConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") + udpConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") if err != nil { r.logger.Warn("Failed to create UDP socket for DNS lookup", zap.String("resolver", r.name), diff --git a/service/tcp.go b/service/tcp.go index 86aae37..9514ce2 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -88,7 +88,7 @@ func (s *TCPRelay) Start(ctx context.Context) error { index := i lnc := &s.listeners[index] - l, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address) + l, _, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address) if err != nil { return err } diff --git a/service/udp_nat.go b/service/udp_nat.go index 748d0bf..20db25f 100644 --- a/service/udp_nat.go +++ b/service/udp_nat.go @@ -139,7 +139,7 @@ func (s *UDPNATRelay) Start(ctx context.Context) error { } func (s *UDPNATRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) { - lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) + lnc.serverConn, _, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) if err != nil { return } @@ -320,7 +320,7 @@ func (s *UDPNATRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRel return } - natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", clientAddrPort), diff --git a/service/udp_nat_mmsg.go b/service/udp_nat_mmsg.go index 0620ffa..e13258e 100644 --- a/service/udp_nat_mmsg.go +++ b/service/udp_nat_mmsg.go @@ -56,7 +56,7 @@ func (s *UDPNATRelay) start(ctx context.Context, index int, lnc *udpRelayServerC } func (s *UDPNATRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error { - serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -275,7 +275,7 @@ func (s *UDPNATRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRe return } - natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", clientAddrPort), diff --git a/service/udp_session.go b/service/udp_session.go index 5763126..a168e8e 100644 --- a/service/udp_session.go +++ b/service/udp_session.go @@ -150,7 +150,7 @@ func (s *UDPSessionRelay) Start(ctx context.Context) error { } func (s *UDPSessionRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) { - lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) + lnc.serverConn, _, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address) if err != nil { return } @@ -372,7 +372,7 @@ func (s *UDPSessionRelay) recvFromServerConnGeneric(ctx context.Context, lnc *ud return } - natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), diff --git a/service/udp_session_mmsg.go b/service/udp_session_mmsg.go index b6829aa..85b6eb3 100644 --- a/service/udp_session_mmsg.go +++ b/service/udp_session_mmsg.go @@ -58,7 +58,7 @@ func (s *UDPSessionRelay) start(ctx context.Context, index int, lnc *udpRelaySer } func (s *UDPSessionRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error { - serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -316,7 +316,7 @@ func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *u return } - natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", &queuedPacket.clientAddrPort), diff --git a/service/udp_transparent_linux.go b/service/udp_transparent_linux.go index 4240882..ab9cf3b 100644 --- a/service/udp_transparent_linux.go +++ b/service/udp_transparent_linux.go @@ -128,7 +128,7 @@ func (s *UDPTransparentRelay) Start(ctx context.Context) error { index := i lnc := &s.listeners[index] - serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) + serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address) if err != nil { return err } @@ -297,7 +297,7 @@ func (s *UDPTransparentRelay) recvFromServerConnRecvmmsg(ctx context.Context, ln return } - natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") + natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "") if err != nil { lnc.logger.Warn("Failed to create UDP socket for new NAT session", zap.Stringer("clientAddress", clientAddrPort), @@ -546,7 +546,7 @@ type transparentConn struct { } func (s *UDPTransparentRelay) newTransparentConn(ctx context.Context, address string, relayBatchSize int, name *byte, namelen uint32) (*transparentConn, error) { - c, err := s.transparentConnListenConfig.ListenUDPRawConn(ctx, "udp", address) + c, _, err := s.transparentConnListenConfig.ListenUDPRawConn(ctx, "udp", address) if err != nil { return nil, err }