Skip to content

Commit

Permalink
udpnat2: Reflect handler
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 19, 2024
1 parent 30fbafd commit 54b69f6
Showing 4 changed files with 109 additions and 108 deletions.
147 changes: 60 additions & 87 deletions common/bufio/copy.go
Original file line number Diff line number Diff line change
@@ -30,27 +30,38 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
cachedBuffer := cachedSrc.ReadCached()
if cachedBuffer != nil {
if !cachedBuffer.IsEmpty() {
dataLen := cachedBuffer.Len()
for _, counter := range readCounters {
counter(int64(dataLen))
}
_, err = destination.Write(cachedBuffer.Bytes())
if err != nil {
cachedBuffer.Release()
return
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
}
cachedBuffer.Release()
continue
}
}
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
break
}
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
}

func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
if srcIsSyscall && dstIsSyscall {
var handled bool
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
if handled {
return
}
}
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
}

@@ -75,6 +86,7 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
}

// Deprecated: not used
func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, source N.ExtendedReader, buffer *buf.Buffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
buffer.IncRef()
defer buffer.DecRef()
@@ -113,19 +125,10 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
}

func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destination)
rearHeadroom := N.CalculateRearHeadroom(destination)
bufferSize := N.CalculateMTU(source, destination)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.BufferSize
}
options := N.NewReadWaitOptions(source, destination)
var notFirstTime bool
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
buffer := options.NewBuffer()
err = source.ReadBuffer(buffer)
if err != nil {
buffer.Release()
@@ -136,7 +139,10 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
for _, counter := range readCounters {
counter(int64(dataLen))
}
options.PostReturn(buffer)
err = destination.WriteBuffer(buffer)
if err != nil {
buffer.Leak()
@@ -146,9 +152,6 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
@@ -196,18 +199,6 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
return group.Run(ctx)
}

// Deprecated: not used
func CopyConnContextList(contextList []context.Context, source net.Conn, destination net.Conn) error {
switch len(contextList) {
case 0:
return CopyConn(context.Background(), source, destination)
case 1:
return CopyConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}

func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64, err error) {
var readCounters, writeCounters []N.CountFunc
var cachedPackets []*N.PacketBuffer
@@ -225,24 +216,24 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
break
}
if cachedPackets != nil {
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets)
n, err = WritePacketWithPool(originSource, destinationConn, cachedPackets, readCounters, writeCounters)
if err != nil {
return
}
}
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
copeN, err := CopyPacketWithCounters(destinationConn, source, originSource, readCounters, writeCounters)
n += copeN
return
}

func CopyPacketWithCounters(destinationConn N.PacketWriter, source N.PacketReader, originSource N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
var (
handled bool
copeN int64
)
readWaiter, isReadWaiter := CreatePacketReadWaiter(source)
if isReadWaiter {
needCopy := readWaiter.InitializeReadWaiter(N.ReadWaitOptions{
FrontHeadroom: frontHeadroom,
RearHeadroom: rearHeadroom,
MTU: N.CalculateMTU(source, destinationConn),
})
needCopy := readWaiter.InitializeReadWaiter(N.NewReadWaitOptions(source, destinationConn))
if !needCopy || common.LowMemory {
handled, copeN, err = copyPacketWaitWithPool(originSource, destinationConn, readWaiter, readCounters, writeCounters, n > 0)
if handled {
@@ -256,71 +247,65 @@ func CopyPacket(destinationConn N.PacketWriter, source N.PacketReader) (n int64,
return
}

func CopyPacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
bufferSize := N.CalculateMTU(source, destinationConn)
if bufferSize > 0 {
bufferSize += frontHeadroom + rearHeadroom
} else {
bufferSize = buf.UDPBufferSize
}
var destination M.Socksaddr
func CopyPacketWithPool(originSource N.PacketReader, destination N.PacketWriter, source N.PacketReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, notFirstTime bool) (n int64, err error) {
options := N.NewReadWaitOptions(source, destination)
var destinationAddress M.Socksaddr
for {
buffer := buf.NewSize(bufferSize)
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
destination, err = source.ReadPacket(buffer)
buffer := options.NewPacketBuffer()
destinationAddress, err = source.ReadPacket(buffer)
if err != nil {
buffer.Release()
return
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, destination)
for _, counter := range readCounters {
counter(int64(dataLen))
}
options.PostReturn(buffer)
err = destination.WritePacket(buffer, destinationAddress)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
n += int64(dataLen)
for _, counter := range readCounters {
counter(int64(dataLen))
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
}

func WritePacketWithPool(originSource N.PacketReader, destinationConn N.PacketWriter, packetBuffers []*N.PacketBuffer) (n int64, err error) {
frontHeadroom := N.CalculateFrontHeadroom(destinationConn)
rearHeadroom := N.CalculateRearHeadroom(destinationConn)
func WritePacketWithPool(originSource N.PacketReader, destination N.PacketWriter, packetBuffers []*N.PacketBuffer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
options := N.NewReadWaitOptions(nil, destination)
var notFirstTime bool
for _, packetBuffer := range packetBuffers {
buffer := buf.NewPacket()
buffer.Resize(frontHeadroom, 0)
buffer.Reserve(rearHeadroom)
_, err = buffer.Write(packetBuffer.Buffer.Bytes())
packetBuffer.Buffer.Release()
for _, counter := range readCounters {
counter(int64(packetBuffer.Buffer.Len()))
}
buffer := options.Copy(packetBuffer.Buffer)
_, err = buffer.Write(buffer.Bytes())
buffer.Release()
if err != nil {
buffer.Release()
continue
}
dataLen := buffer.Len()
buffer.OverCap(rearHeadroom)
err = destinationConn.WritePacket(buffer, packetBuffer.Destination)
options.PostReturn(buffer)
err = destination.WritePacket(buffer, packetBuffer.Destination)
if err != nil {
buffer.Leak()
if !notFirstTime {
err = N.ReportHandshakeFailure(originSource, err)
}
return
}
for _, counter := range writeCounters {
counter(int64(dataLen))
}
n += int64(dataLen)
notFirstTime = true
}
return
}
@@ -339,15 +324,3 @@ func CopyPacketConn(ctx context.Context, source N.PacketConn, destination N.Pack
group.FastFail()
return group.Run(ctx)
}

// Deprecated: not used
func CopyPacketConnContextList(contextList []context.Context, source N.PacketConn, destination N.PacketConn) error {
switch len(contextList) {
case 0:
return CopyPacketConn(context.Background(), source, destination)
case 1:
return CopyPacketConn(contextList[0], source, destination)
default:
panic("invalid context list")
}
}
8 changes: 8 additions & 0 deletions common/network/direct.go
Original file line number Diff line number Diff line change
@@ -15,6 +15,14 @@ type ReadWaitOptions struct {
MTU int
}

func NewReadWaitOptions(source any, destination any) ReadWaitOptions {
return ReadWaitOptions{
FrontHeadroom: CalculateFrontHeadroom(destination),
RearHeadroom: CalculateRearHeadroom(destination),
MTU: CalculateMTU(source, destination),
}
}

func (o ReadWaitOptions) NeedHeadroom() bool {
return o.FrontHeadroom > 0 || o.RearHeadroom > 0
}
48 changes: 34 additions & 14 deletions common/udpnat2/conn.go
Original file line number Diff line number Diff line change
@@ -7,22 +7,35 @@ import (
"time"

"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)

type Conn struct {
type Conn interface {
N.PacketConn
SetHandler(handler Handler)
}

type Handler interface {
N.UDPHandlerEx
Close() error
}

var _ Conn = (*natConn)(nil)

type natConn struct {
writer N.PacketWriter
localAddr M.Socksaddr
handler N.UDPHandlerEx
handler Handler
packetChan chan *N.PacketBuffer
doneChan chan struct{}
readDeadline pipe.Deadline
readWaitOptions N.ReadWaitOptions
}

func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
select {
case p := <-c.packetChan:
_, err = buffer.ReadOnceFrom(p.Buffer)
@@ -37,12 +50,13 @@ func (c *Conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) {
}
}

func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return c.writer.WritePacket(buffer, destination)
}

func (c *Conn) SetHandler(handler N.UDPHandlerEx) {
func (c *natConn) SetHandler(handler Handler) {
c.handler = handler
c.readWaitOptions = N.NewReadWaitOptions(c.writer, handler)
fetch:
for {
select {
@@ -56,12 +70,12 @@ fetch:
}
}

func (c *Conn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) {
c.readWaitOptions = options
return false
}

func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) {
select {
case packet := <-c.packetChan:
buffer = c.readWaitOptions.Copy(packet.Buffer)
@@ -75,36 +89,42 @@ func (c *Conn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, er
}
}

func (c *Conn) Close() error {
func (c *natConn) Close() error {
select {
case <-c.doneChan:
default:
close(c.doneChan)
if c.handler != nil {
err := c.handler.Close()
if err != nil {
return E.Cause(err, "close handler")
}
}
}
return nil
}

func (c *Conn) LocalAddr() net.Addr {
func (c *natConn) LocalAddr() net.Addr {
return c.localAddr
}

func (c *Conn) RemoteAddr() net.Addr {
func (c *natConn) RemoteAddr() net.Addr {
return M.Socksaddr{}
}

func (c *Conn) SetDeadline(t time.Time) error {
func (c *natConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}

func (c *Conn) SetReadDeadline(t time.Time) error {
func (c *natConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}

func (c *Conn) SetWriteDeadline(t time.Time) error {
func (c *natConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}

func (c *Conn) Upstream() any {
func (c *natConn) Upstream() any {
return c.writer
}
Loading

0 comments on commit 54b69f6

Please sign in to comment.