Skip to content

Commit

Permalink
🥧 conn: add SocketInfo
Browse files Browse the repository at this point in the history
SocketInfo allows propagating the result of applying the socket options.
  • Loading branch information
database64128 committed Sep 29, 2024
1 parent ae4d561 commit 0f8fcd9
Show file tree
Hide file tree
Showing 18 changed files with 92 additions and 57 deletions.
66 changes: 43 additions & 23 deletions conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,31 @@ import (
"github.com/database64128/tfo-go/v2"
)

type setFunc = func(fd int, network string) error
// SocketInfo contains information about a socket.
type SocketInfo struct {
// MaxUDPGSOSegments is the maximum number of UDP GSO segments supported by the socket.
//
// If UDP GSO is not enabled on the socket, or the system does not support UDP GSO, the value is 1.
//
// The value is 0 if the socket is not a UDP socket.
MaxUDPGSOSegments int

// UDPGenericReceiveOffload indicates whether UDP GRO is enabled on the socket.
UDPGenericReceiveOffload bool
}

type setFunc = func(fd int, network string, info *SocketInfo) error

type setFuncSlice []setFunc

func (fns setFuncSlice) controlContextFunc() func(ctx context.Context, network, address string, c syscall.RawConn) error {
func (fns setFuncSlice) controlContextFunc(info *SocketInfo) func(ctx context.Context, network, address string, c syscall.RawConn) error {
if len(fns) == 0 {
return nil
}
return func(ctx context.Context, network, address string, c syscall.RawConn) (err error) {
if cerr := c.Control(func(fd uintptr) {
for _, fn := range fns {
if err = fn(int(fd), network); err != nil {
if err = fn(int(fd), network, info); err != nil {
return
}
}
Expand All @@ -30,14 +43,14 @@ func (fns setFuncSlice) controlContextFunc() func(ctx context.Context, network,
}
}

func (fns setFuncSlice) controlFunc() func(network, address string, c syscall.RawConn) error {
func (fns setFuncSlice) controlFunc(info *SocketInfo) func(network, address string, c syscall.RawConn) error {
if len(fns) == 0 {
return nil
}
return func(network, address string, c syscall.RawConn) (err error) {
if cerr := c.Control(func(fd uintptr) {
for _, fn := range fns {
if err = fn(int(fd), network); err != nil {
if err = fn(int(fd), network, info); err != nil {
return
}
}
Expand All @@ -48,25 +61,32 @@ func (fns setFuncSlice) controlFunc() func(network, address string, c syscall.Ra
}
}

// ListenConfig is [tfo.ListenConfig] but provides a subjectively nicer API.
type ListenConfig tfo.ListenConfig
// ListenConfig wraps a [tfo.ListenConfig] and provides a subjectively nicer API.
type ListenConfig struct {
tlc tfo.ListenConfig
fns setFuncSlice
}

// ListenTCP wraps [tfo.ListenConfig.Listen] and returns a [*net.TCPListener] directly.
func (lc *ListenConfig) ListenTCP(ctx context.Context, network, address string) (*net.TCPListener, error) {
l, err := (*tfo.ListenConfig)(lc).Listen(ctx, network, address)
func (lc *ListenConfig) ListenTCP(ctx context.Context, network, address string) (tln *net.TCPListener, info SocketInfo, err error) {
tlc := lc.tlc
tlc.Control = lc.fns.controlFunc(&info)
ln, err := tlc.Listen(ctx, network, address)
if err != nil {
return nil, err
return nil, info, err
}
return l.(*net.TCPListener), nil
return ln.(*net.TCPListener), info, nil
}

// ListenUDP wraps [net.ListenConfig.ListenPacket] and returns a [*net.UDPConn] directly.
func (lc *ListenConfig) ListenUDP(ctx context.Context, network, address string) (*net.UDPConn, error) {
pc, err := lc.ListenConfig.ListenPacket(ctx, network, address)
func (lc *ListenConfig) ListenUDP(ctx context.Context, network, address string) (uc *net.UDPConn, info SocketInfo, err error) {
nlc := lc.tlc.ListenConfig
nlc.Control = lc.fns.controlFunc(&info)
pc, err := nlc.ListenPacket(ctx, network, address)
if err != nil {
return nil, err
return nil, info, err
}
return pc.(*net.UDPConn), nil
return pc.(*net.UDPConn), info, nil
}

// ListenerSocketOptions contains listener-specific socket options.
Expand Down Expand Up @@ -164,17 +184,17 @@ type ListenerSocketOptions struct {
ReceiveOriginalDestAddr bool
}

// ListenConfig returns a [ListenConfig] with a control function that sets the socket options.
// ListenConfig returns a [ListenConfig] that sets the socket options.
func (lso ListenerSocketOptions) ListenConfig() ListenConfig {
lc := ListenConfig{
ListenConfig: net.ListenConfig{
Control: lso.buildSetFns().controlFunc(),
tlc: tfo.ListenConfig{
Backlog: lso.TCPFastOpenBacklog,
DisableTFO: !lso.TCPFastOpen,
Fallback: lso.TCPFastOpenFallback,
},
Backlog: lso.TCPFastOpenBacklog,
DisableTFO: !lso.TCPFastOpen,
Fallback: lso.TCPFastOpenFallback,
fns: lso.buildSetFns(),
}
lc.SetMultipathTCP(lso.MultipathTCP)
lc.tlc.SetMultipathTCP(lso.MultipathTCP)
return lc
}

Expand Down Expand Up @@ -277,7 +297,7 @@ type DialerSocketOptions struct {
func (dso DialerSocketOptions) Dialer() Dialer {
d := Dialer{
Dialer: net.Dialer{
ControlContext: dso.buildSetFns().controlContextFunc(),
ControlContext: dso.buildSetFns().controlContextFunc(nil),
},
DisableTFO: !dso.TCPFastOpen,
Fallback: dso.TCPFastOpenFallback,
Expand Down
4 changes: 2 additions & 2 deletions conn/conn_bufsize_posix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package conn

func (fns setFuncSlice) appendSetSendBufferSize(size int) setFuncSlice {
if size > 0 {
return append(fns, func(fd int, _ string) error {
return append(fns, func(fd int, _ string, _ *SocketInfo) error {
return setSendBufferSize(fd, size)
})
}
Expand All @@ -13,7 +13,7 @@ func (fns setFuncSlice) appendSetSendBufferSize(size int) setFuncSlice {

func (fns setFuncSlice) appendSetRecvBufferSize(size int) setFuncSlice {
if size > 0 {
return append(fns, func(fd int, _ string) error {
return append(fns, func(fd int, _ string, _ *SocketInfo) error {
return setRecvBufferSize(fd, size)
})
}
Expand Down
4 changes: 3 additions & 1 deletion conn/conn_darwinlinuxwindows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package conn

func (fns setFuncSlice) appendSetRecvPktinfoFunc(recvPktinfo bool) setFuncSlice {
if recvPktinfo {
return append(fns, setRecvPktinfo)
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setRecvPktinfo(fd, network)
})
}
return fns
}
2 changes: 1 addition & 1 deletion conn/conn_flags_tclass.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func ParseFlagsForError(flags int) error {

func (fns setFuncSlice) appendSetTrafficClassFunc(trafficClass int) setFuncSlice {
if trafficClass != 0 {
return append(fns, func(fd int, network string) error {
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setTrafficClass(fd, network, trafficClass)
})
}
Expand Down
2 changes: 1 addition & 1 deletion conn/conn_freebsdlinux.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package conn

func (fns setFuncSlice) appendSetFwmarkFunc(fwmark int) setFuncSlice {
if fwmark != 0 {
return append(fns, func(fd int, network string) error {
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setFwmark(fd, fwmark)
})
}
Expand Down
18 changes: 12 additions & 6 deletions conn/conn_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ func setTCPUserTimeout(fd, msecs int) error {
return nil
}

func setUDPGenericReceiveOffload(fd int) {
_ = unix.SetsockoptInt(fd, unix.IPPROTO_UDP, unix.UDP_GRO, 1)
func setUDPGenericReceiveOffload(fd int, info *SocketInfo) {
if err := unix.SetsockoptInt(fd, unix.IPPROTO_UDP, unix.UDP_GRO, 1); err == nil {
info.UDPGenericReceiveOffload = true
}
}

func setTransparent(fd int, network string) error {
Expand Down Expand Up @@ -143,7 +145,7 @@ func setRecvOrigDstAddr(fd int, network string) error {

func (fns setFuncSlice) appendSetTCPDeferAcceptFunc(deferAcceptSecs int) setFuncSlice {
if deferAcceptSecs > 0 {
return append(fns, func(fd int, network string) error {
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setTCPDeferAccept(fd, deferAcceptSecs)
})
}
Expand All @@ -152,7 +154,7 @@ func (fns setFuncSlice) appendSetTCPDeferAcceptFunc(deferAcceptSecs int) setFunc

func (fns setFuncSlice) appendSetTCPUserTimeoutFunc(userTimeoutMsecs int) setFuncSlice {
if userTimeoutMsecs > 0 {
return append(fns, func(fd int, network string) error {
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setTCPUserTimeout(fd, userTimeoutMsecs)
})
}
Expand All @@ -161,14 +163,18 @@ func (fns setFuncSlice) appendSetTCPUserTimeoutFunc(userTimeoutMsecs int) setFun

func (fns setFuncSlice) appendSetTransparentFunc(transparent bool) setFuncSlice {
if transparent {
return append(fns, setTransparent)
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setTransparent(fd, network)
})
}
return fns
}

func (fns setFuncSlice) appendSetRecvOrigDstAddrFunc(recvOrigDstAddr bool) setFuncSlice {
if recvOrigDstAddr {
return append(fns, setRecvOrigDstAddr)
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setRecvOrigDstAddr(fd, network)
})
}
return fns
}
Expand Down
11 changes: 7 additions & 4 deletions conn/conn_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
)

// ListenUDPRawConn is like [ListenUDP] but wraps the [*net.UDPConn] in a [rawUDPConn] for batch I/O.
func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (rawUDPConn, error) {
pc, err := lc.ListenConfig.ListenPacket(ctx, network, address)
func (lc *ListenConfig) ListenUDPRawConn(ctx context.Context, network, address string) (c rawUDPConn, info SocketInfo, err error) {
nlc := lc.tlc.ListenConfig
nlc.Control = lc.fns.controlFunc(&info)
pc, err := nlc.ListenPacket(ctx, network, address)
if err != nil {
return rawUDPConn{}, err
return rawUDPConn{}, info, err
}
return NewRawUDPConn(pc.(*net.UDPConn))
c, err = NewRawUDPConn(pc.(*net.UDPConn))
return c, info, err
}
4 changes: 3 additions & 1 deletion conn/conn_pmtud.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package conn

func (fns setFuncSlice) appendSetPMTUDFunc(pmtud bool) setFuncSlice {
if pmtud {
return append(fns, setPMTUD)
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setPMTUD(fd, network)
})
}
return fns
}
2 changes: 1 addition & 1 deletion conn/conn_reuseport.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func setReusePort(fd int) error {

func (fns setFuncSlice) appendSetReusePortFunc(reusePort bool) setFuncSlice {
if reusePort {
return append(fns, func(fd int, network string) error {
return append(fns, func(fd int, network string, _ *SocketInfo) error {
return setReusePort(fd)
})
}
Expand Down
4 changes: 2 additions & 2 deletions conn/conn_udpgro.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package conn

func (fns setFuncSlice) appendSetUDPGenericReceiveOffloadFunc(gro bool) setFuncSlice {
if gro {
return append(fns, func(fd int, _ string) error {
setUDPGenericReceiveOffload(fd)
return append(fns, func(fd int, _ string, info *SocketInfo) error {
setUDPGenericReceiveOffload(fd, info)
return nil
})
}
Expand Down
6 changes: 4 additions & 2 deletions conn/conn_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@ func setPMTUD(fd int, network string) error {
return nil
}

func setUDPGenericReceiveOffload(fd int) {
func setUDPGenericReceiveOffload(fd int, info *SocketInfo) {
// Both quinn and msquic set this to 65535.
_ = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_UDP, windows.UDP_RECV_MAX_COALESCED_SIZE, 65535)
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_UDP, windows.UDP_RECV_MAX_COALESCED_SIZE, 65535); err == nil {
info.UDPGenericReceiveOffload = true
}
}

func setRecvPktinfo(fd int, network string) error {
Expand Down
2 changes: 1 addition & 1 deletion dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (r *Resolver) sendQueriesUDP(ctx context.Context, nameString string, q4Pkt,
}
defer clientSession.Close()

udpConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
udpConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
if err != nil {
r.logger.Warn("Failed to create UDP socket for DNS lookup",
zap.String("resolver", r.name),
Expand Down
2 changes: 1 addition & 1 deletion service/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (s *TCPRelay) Start(ctx context.Context) error {
index := i
lnc := &s.listeners[index]

l, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address)
l, _, err := lnc.listenConfig.ListenTCP(ctx, lnc.network, lnc.address)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions service/udp_nat.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (s *UDPNATRelay) Start(ctx context.Context) error {
}

func (s *UDPNATRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) {
lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
lnc.serverConn, _, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
if err != nil {
return
}
Expand Down Expand Up @@ -320,7 +320,7 @@ func (s *UDPNATRelay) recvFromServerConnGeneric(ctx context.Context, lnc *udpRel
return
}

natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
natConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
if err != nil {
lnc.logger.Warn("Failed to create UDP socket for new NAT session",
zap.Stringer("clientAddress", clientAddrPort),
Expand Down
4 changes: 2 additions & 2 deletions service/udp_nat_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *UDPNATRelay) start(ctx context.Context, index int, lnc *udpRelayServerC
}

func (s *UDPNATRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error {
serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
if err != nil {
return err
}
Expand Down Expand Up @@ -275,7 +275,7 @@ func (s *UDPNATRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *udpRe
return
}

natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
if err != nil {
lnc.logger.Warn("Failed to create UDP socket for new NAT session",
zap.Stringer("clientAddress", clientAddrPort),
Expand Down
4 changes: 2 additions & 2 deletions service/udp_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (s *UDPSessionRelay) Start(ctx context.Context) error {
}

func (s *UDPSessionRelay) startGeneric(ctx context.Context, index int, lnc *udpRelayServerConn) (err error) {
lnc.serverConn, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
lnc.serverConn, _, err = lnc.listenConfig.ListenUDP(ctx, lnc.network, lnc.address)
if err != nil {
return
}
Expand Down Expand Up @@ -372,7 +372,7 @@ func (s *UDPSessionRelay) recvFromServerConnGeneric(ctx context.Context, lnc *ud
return
}

natConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
natConn, _, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
if err != nil {
lnc.logger.Warn("Failed to create UDP socket for new NAT session",
zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
Expand Down
4 changes: 2 additions & 2 deletions service/udp_session_mmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *UDPSessionRelay) start(ctx context.Context, index int, lnc *udpRelaySer
}

func (s *UDPSessionRelay) startMmsg(ctx context.Context, index int, lnc *udpRelayServerConn) error {
serverConn, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
serverConn, _, err := lnc.listenConfig.ListenUDPRawConn(ctx, lnc.network, lnc.address)
if err != nil {
return err
}
Expand Down Expand Up @@ -316,7 +316,7 @@ func (s *UDPSessionRelay) recvFromServerConnRecvmmsg(ctx context.Context, lnc *u
return
}

natConn, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
natConn, _, err := clientInfo.ListenConfig.ListenUDPRawConn(ctx, "udp", "")
if err != nil {
lnc.logger.Warn("Failed to create UDP socket for new NAT session",
zap.Stringer("clientAddress", &queuedPacket.clientAddrPort),
Expand Down
Loading

0 comments on commit 0f8fcd9

Please sign in to comment.