Skip to content

Commit

Permalink
🍮 conn: rewrite socket cmsg handling
Browse files Browse the repository at this point in the history
Add support for UDP GRO and GSO on Linux and Windows.
  • Loading branch information
database64128 committed Sep 26, 2024
1 parent 5dab961 commit f9866fd
Show file tree
Hide file tree
Showing 12 changed files with 401 additions and 127 deletions.
28 changes: 28 additions & 0 deletions conn/cmsg.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package conn

import "net/netip"

// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages.
const SocketControlMessageBufferSize = socketControlMessageBufferSize

// SocketControlMessage contains information that can be parsed from or put into socket control messages.
type SocketControlMessage struct {
// PktinfoAddr is the IP address of the network interface the packet was received from.
PktinfoAddr netip.Addr

// PktinfoIfindex is the index of the network interface the packet was received from.
PktinfoIfindex uint32

// SegmentSize is the UDP GRO/GSO segment size.
SegmentSize uint32
}

// ParseSocketControlMessage parses a sequence of socket control messages and returns the parsed information.
func ParseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) {
return parseSocketControlMessage(cmsg)
}

// AppendTo appends the socket control message to the buffer.
func (m SocketControlMessage) AppendTo(b []byte) []byte {
return m.appendTo(b)
}
93 changes: 93 additions & 0 deletions conn/cmsg_darwin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package conn

import (
"fmt"
"net/netip"
"unsafe"

"github.com/database64128/shadowsocks-go/slicehelper"
"golang.org/x/sys/unix"
)

const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo

const cmsgAlignTo = 4

func cmsgAlign(n uint32) uint32 {
return (n + cmsgAlignTo - 1) & ^uint32(cmsgAlignTo-1)
}

func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) {
for len(cmsg) >= unix.SizeofCmsghdr {
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg)))
msgSize := cmsgAlign(cmsghdr.Len)
if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) {
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len)
}

switch {
case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO:
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo {
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo unix.Inet4Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst)
m.PktinfoIfindex = pktinfo.Ifindex

case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO:
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo {
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo unix.Inet6Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr)
m.PktinfoIfindex = pktinfo.Ifindex
}

cmsg = cmsg[msgSize:]
}

return m, nil
}

const (
alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1)
alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1)
)

func (m SocketControlMessage) appendTo(b []byte) []byte {
switch {
case m.PktinfoAddr.Is4():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo)
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = unix.Cmsghdr{
Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo,
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
}
pktinfo := unix.Inet4Pktinfo{
Ifindex: m.PktinfoIfindex,
Spec_dst: m.PktinfoAddr.As4(),
}
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo))

case m.PktinfoAddr.Is6():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo)
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = unix.Cmsghdr{
Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo,
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
}
pktinfo := unix.Inet6Pktinfo{
Addr: m.PktinfoAddr.As16(),
Ifindex: m.PktinfoIfindex,
}
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo))
}

return b
}
116 changes: 116 additions & 0 deletions conn/cmsg_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package conn

import (
"fmt"
"net/netip"
"unsafe"

"github.com/database64128/shadowsocks-go/slicehelper"
"golang.org/x/sys/unix"
)

const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo +
unix.SizeofCmsghdr + alignedSizeofGROSegmentSize

const sizeofGROSegmentSize = int(unsafe.Sizeof(uint16(0)))

func cmsgAlign(n uint64) uint64 {
return (n + unix.SizeofPtr - 1) & ^uint64(unix.SizeofPtr-1)
}

func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) {
for len(cmsg) >= unix.SizeofCmsghdr {
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg)))
msgSize := cmsgAlign(cmsghdr.Len)
if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) {
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len)
}

switch {
case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO:
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo {
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo unix.Inet4Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst)
m.PktinfoIfindex = uint32(pktinfo.Ifindex)

case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO:
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo {
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo unix.Inet6Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr)
m.PktinfoIfindex = pktinfo.Ifindex

case cmsghdr.Level == unix.IPPROTO_UDP && cmsghdr.Type == unix.UDP_GRO:
if len(cmsg) < unix.SizeofCmsghdr+sizeofGROSegmentSize {
return m, fmt.Errorf("invalid UDP_GRO control message length %d", cmsghdr.Len)
}
var segmentSize uint16
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize), cmsg[unix.SizeofCmsghdr:])
m.SegmentSize = uint32(segmentSize)
}

cmsg = cmsg[msgSize:]
}

return m, nil
}

const (
alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1)
alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1)
alignedSizeofGROSegmentSize = (sizeofGROSegmentSize + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1)
)

func (m SocketControlMessage) appendTo(b []byte) []byte {
switch {
case m.PktinfoAddr.Is4():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo)
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = unix.Cmsghdr{
Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo,
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
}
pktinfo := unix.Inet4Pktinfo{
Ifindex: int32(m.PktinfoIfindex),
Spec_dst: m.PktinfoAddr.As4(),
}
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo))

case m.PktinfoAddr.Is6():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo)
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = unix.Cmsghdr{
Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo,
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
}
pktinfo := unix.Inet6Pktinfo{
Addr: m.PktinfoAddr.As16(),
Ifindex: m.PktinfoIfindex,
}
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo))
}

if m.SegmentSize > 0 {
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofGROSegmentSize)
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = unix.Cmsghdr{
Len: unix.SizeofCmsghdr + uint64(sizeofGROSegmentSize),
Level: unix.IPPROTO_UDP,
Type: unix.UDP_GRO,
}
segmentSize := uint16(m.SegmentSize)
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize))
}

return b
}
13 changes: 13 additions & 0 deletions conn/cmsg_stub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//go:build !darwin && !linux && !windows

package conn

const socketControlMessageBufferSize = 0

func parseSocketControlMessage(_ []byte) (SocketControlMessage, error) {
return SocketControlMessage{}, nil
}

func (SocketControlMessage) appendTo(b []byte) []byte {
return b
}
139 changes: 139 additions & 0 deletions conn/cmsg_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package conn

import (
"fmt"
"net/netip"
"unsafe"

"github.com/database64128/shadowsocks-go/slicehelper"
"golang.org/x/sys/windows"
)

const socketControlMessageBufferSize = sizeofCmsghdr + alignedSizeofInet6Pktinfo +
sizeofCmsghdr + alignedSizeofCoalescedInfo

const (
sizeofPtr = int(unsafe.Sizeof(uintptr(0)))
sizeofCmsghdr = int(unsafe.Sizeof(Cmsghdr{}))
sizeofInet4Pktinfo = int(unsafe.Sizeof(Inet4Pktinfo{}))
sizeofInet6Pktinfo = int(unsafe.Sizeof(Inet6Pktinfo{}))
sizeofCoalescedInfo = int(unsafe.Sizeof(uint32(0)))
)

// Structure CMSGHDR from ws2def.h
type Cmsghdr struct {
Len uintptr
Level int32
Type int32
}

// Structure IN_PKTINFO from ws2ipdef.h
type Inet4Pktinfo struct {
Addr [4]byte
Ifindex uint32
}

// Structure IN6_PKTINFO from ws2ipdef.h
type Inet6Pktinfo struct {
Addr [16]byte
Ifindex uint32
}

func cmsgAlign(n int) int {
return (n + sizeofPtr - 1) & ^(sizeofPtr - 1)
}

func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) {
for len(cmsg) >= sizeofCmsghdr {
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg)))
msgLen := int(cmsghdr.Len)
msgSize := cmsgAlign(msgLen)
if msgLen < sizeofCmsghdr || msgSize > len(cmsg) {
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len)
}

switch {
case cmsghdr.Level == windows.IPPROTO_IP && cmsghdr.Type == windows.IP_PKTINFO:
if len(cmsg) < sizeofCmsghdr+sizeofInet4Pktinfo {
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo Inet4Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo), cmsg[sizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Addr)
m.PktinfoIfindex = pktinfo.Ifindex

case cmsghdr.Level == windows.IPPROTO_IPV6 && cmsghdr.Type == windows.IPV6_PKTINFO:
if len(cmsg) < sizeofCmsghdr+sizeofInet6Pktinfo {
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len)
}
var pktinfo Inet6Pktinfo
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo), cmsg[sizeofCmsghdr:])
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr)
m.PktinfoIfindex = pktinfo.Ifindex

case cmsghdr.Level == windows.IPPROTO_UDP && cmsghdr.Type == windows.UDP_COALESCED_INFO:
if len(cmsg) < sizeofCmsghdr+sizeofCoalescedInfo {
return m, fmt.Errorf("invalid UDP_COALESCED_INFO control message length %d", cmsghdr.Len)
}
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo), cmsg[sizeofCmsghdr:])
}

cmsg = cmsg[msgSize:]
}

return m, nil
}

const (
alignedSizeofInet4Pktinfo = (sizeofInet4Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1)
alignedSizeofInet6Pktinfo = (sizeofInet6Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1)
alignedSizeofCoalescedInfo = (sizeofCoalescedInfo + sizeofPtr - 1) & ^(sizeofPtr - 1)
)

func (m SocketControlMessage) appendTo(b []byte) []byte {
switch {
case m.PktinfoAddr.Is4():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet4Pktinfo)
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = Cmsghdr{
Len: uintptr(sizeofCmsghdr + sizeofInet4Pktinfo),
Level: windows.IPPROTO_IP,
Type: windows.IP_PKTINFO,
}
pktinfo := Inet4Pktinfo{
Addr: m.PktinfoAddr.As4(),
Ifindex: m.PktinfoIfindex,
}
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo))

case m.PktinfoAddr.Is6():
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet6Pktinfo)
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = Cmsghdr{
Len: uintptr(sizeofCmsghdr + sizeofInet6Pktinfo),
Level: windows.IPPROTO_IPV6,
Type: windows.IPV6_PKTINFO,
}
pktinfo := Inet6Pktinfo{
Addr: m.PktinfoAddr.As16(),
Ifindex: m.PktinfoIfindex,
}
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo))
}

if m.SegmentSize > 0 {
var msgBuf []byte
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofCoalescedInfo)
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf)))
*cmsghdr = Cmsghdr{
Len: uintptr(sizeofCmsghdr + sizeofCoalescedInfo),
Level: windows.IPPROTO_UDP,
Type: windows.UDP_COALESCED_INFO,
}
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo))
}

return b
}
Loading

0 comments on commit f9866fd

Please sign in to comment.